From 8dee68a0a22568dbeed6d4563372b25e1e825fb0 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 13:41:14 +0800 Subject: [PATCH 1/8] bloom policy --- colossalai/shardformer/policies/bloom.py | 220 ++++++++++++++++++++++- 1 file changed, 219 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 662ff5b4977a..7452b611bed0 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,11 +1,26 @@ +import warnings +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch import torch.nn as nn +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.bloom.modeling_bloom import BloomModel +from transformers.utils import logging import colossalai.shardformer.layer as col_nn +from colossalai.pipeline.stage_manager import PipelineStageManager from .._utils import getattr_, setattr_ from ..modeling.bloom import build_bloom_alibi_tensor_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +logger = logging.get_logger(__name__) + class BloomPolicy(Policy): @@ -110,7 +125,40 @@ def postprocess(self): class BloomModelPolicy(BloomPolicy): - pass + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + from transformers.models.bloom.modeling_bloom import BloomModel + if self.pipeline_stage_manager: + module_policy[BloomModel] = ModulePolicyDescription( + method_replacement={"forward": partial(bloom_model_forward, stage_manager=self.pipeline_stage_manager)}) + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if self.stage_manager.is_first_stage(): + held_layers.append(module.word_embeddings) + held_layers.append(module.word_embeddings_layernorm) + + start_idx, end_idx = self.get_stage_index(layers_per_stage, self.stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + + if self.stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + + return held_layers + + def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]: + '''no shared params in bloommodel''' + pass class BloomForCausalLMPolicy(BloomPolicy): @@ -181,3 +229,173 @@ def module_policy(self): class BloomForQuestionAnsweringPolicy(BloomPolicy): # No head sharding as the output features is only 2 pass + + +def bloom_model_forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + **deprecated_arguments, +) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # add warnings here + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + # case: First stage of training + if stage_manager.is_first_stage(): + # check input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + # initialize in the first stage and then pass to the next stage + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + + # extra recording tensor should be generated in the first stage + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] # source_len + + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + # causal_mask is constructed every stage and its input is passed through different stages + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + # calculate the num_layers + num_layers_per_stage = len(self.h) // stage_manager.num_stages + start_layer = stage_manager.stage * num_layers_per_stage + end_layer = (stage_manager.stage + 1) * num_layers_per_stage + + for i, (block, layer_past) in enumerate(zip(self.h[start_layer:end_layer], past_key_values[start_layer:end_layer])): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + + if use_cache is True: + presents = presents + (outputs[1],) + if output_attentions: + all_self_attentions = all_self_attentions + \ + (outputs[2 if use_cache else 1],) + + if stage_manager.is_last_stage(): + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + # TODO: deal with all_hidden_states, all_self_attentions, presents + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + # attention_mask is not returned ; presents = past_key_values + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) From 9cb30051339477ffd9fd73d1687e0fa9a5c1e33e Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 16:14:26 +0800 Subject: [PATCH 2/8] llama pipeline forward and tests --- .../shardformer/policies/auto_policy.py | 2 +- colossalai/shardformer/policies/bert.py | 2 +- colossalai/shardformer/policies/llama.py | 200 +++++++++++++++++- tests/test_shardformer/test_model/_utils.py | 1 + .../test_model/test_shard_llama_pipeline.py | 92 ++++++++ 5 files changed, 294 insertions(+), 3 deletions(-) create mode 100644 tests/test_shardformer/test_model/test_shard_llama_pipeline.py diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 640b61b579bd..0ad9a3e95a0e 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -45,7 +45,7 @@ class PolicyLocation: # LLaMA "transformers.models.llama.modeling_llama.LlamaModel": - PolicyLocation(file_name="llama", class_name="LlamaPolicy"), + PolicyLocation(file_name="llama", class_name="LlamaModelPolicy"), "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"), "transformers.models.llama.modeling_llama.LlamaForSequenceClassification": diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index eacd0b449ad4..2b2c003ffb04 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -193,7 +193,7 @@ def get_held_layers(self) -> List[Module]: module = self.model stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(self.model.encoder.layer), stage_manager.num_stages) + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index b10e07560d22..c6fe5c395aa7 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,11 +1,28 @@ -from typing import Dict, Union +import math +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union +import torch import torch.nn as nn +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from transformers.models.llama.modeling_llama import LlamaModel +from transformers.utils import ModelOutput, logging +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +logger = logging.get_logger(__name__) + __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] @@ -106,6 +123,39 @@ def postprocess(self): return self.model +class LlamaModelPolicy(LlamaPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + from transformers.models.llama.modeling_llama import LlamaModel + if self.pipeline_stage_manager: + # set None as default + module_policy[LlamaModel] = ModulePolicyDescription( + method_replacement={'forward': partial(llama_model_forward, stage_manager=self.pipeline_stage_manager)}) + return module_policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in bert model""" + return [] + + class LlamaForCausalLMPolicy(LlamaPolicy): def module_policy(self): @@ -144,3 +194,151 @@ def module_policy(self): } policy.update(new_item) return policy + + +def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, +): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange(past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + # TODO: we should recive the attn mask of 1st stage and send it to the other stages + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, + past_key_values_length) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx]): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + # always return dict for imediate stage + return {'hidden_states': hidden_states, 'attention_mask': attention_mask} diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index de8cb65d21d0..f26c6622da7e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -39,6 +39,7 @@ def build_pipeline_model(model_fn, shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, enable_tensor_parallelism=enable_tensor_parallelism, pipeline_stage_manager=stage_manager) + shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) return org_model.cuda(), sharded_model.cuda() diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py new file mode 100644 index 000000000000..90a908e04f2a --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py @@ -0,0 +1,92 @@ +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + pass + + +@parameterize('enable_fused_normalization', [False]) +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('use_lazy_init', [False]) +#TODO: merge this into test_shard_bert +def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + x = torch.randint(0, 1000, (2, 3)).cuda() + hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name == 'transformers_llama': + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + + layers_per_stage = Policy.distribute_layers(len(org_model.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) + assert output['hidden_states'].shape == (2, 3, 128) + else: + attention_mask = torch.ones((2, 3)).cuda() + output = sharded_model(hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager, + stage_index=stage_index) + # print(output[0].shape) + assert output[0].shape == (2, 3, 128) + + torch.cuda.empty_cache() + + +def check_bert(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bert_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bert(): + spawn(check_bert, 4) + + +if __name__ == "__main__": + test_bert() From a2144142d6c31eb3a9e0a73411f6596ea2a72d8d Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 16:44:57 +0800 Subject: [PATCH 3/8] fix the output and attention_mask --- colossalai/shardformer/policies/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index c6fe5c395aa7..85bb8c37cc07 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -268,7 +268,7 @@ def llama_model_forward( # for the other stages, hidden_states is the output of the previous stage # TODO: we should recive the attn mask of 1st stage and send it to the other stages if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) + attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device) attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length) @@ -341,4 +341,4 @@ def custom_forward(*inputs): attentions=all_self_attns, ) # always return dict for imediate stage - return {'hidden_states': hidden_states, 'attention_mask': attention_mask} + return {'hidden_states': hidden_states} From b625f2fc16302c8e60741ef66264abd45c866bc4 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 16:51:02 +0800 Subject: [PATCH 4/8] fix name --- .../test_model/test_shard_llama_pipeline.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py index 90a908e04f2a..9b2a3f0ba0c1 100644 --- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py @@ -26,8 +26,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [False]) @parameterize('enable_tensor_parallelism', [False]) @parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_bert -def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): +#TODO: merge this into test_shard_llama +def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): DP_DIM, PP_DIM = 0, 1 DP_SIZE, PP_SIZE = 2, 2 RANK_TO_COORDINATE = { @@ -75,18 +75,18 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz torch.cuda.empty_cache() -def check_bert(rank, world_size, port): +def check_llama(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_bert_test() + run_llama_test() @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() -def test_bert(): - spawn(check_bert, 4) +def test_llama(): + spawn(check_llama, 4) if __name__ == "__main__": - test_bert() + test_llama() From 67e6a717c5b99bb84baf7de97c440db36b82e36a Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 10 Jul 2023 17:31:06 +0800 Subject: [PATCH 5/8] bind argument to policy --- .../test_model/test_shard_llama_pipeline.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py index 9b2a3f0ba0c1..8f27eb914e3b 100644 --- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py @@ -52,23 +52,15 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la if name == 'transformers_llama': org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) - - layers_per_stage = Policy.distribute_layers(len(org_model.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - if stage_manager.stage == 0: attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, - attention_mask=attention_mask, - stage_manager=stage_manager, - stage_index=stage_index) + output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) assert output['hidden_states'].shape == (2, 3, 128) else: attention_mask = torch.ones((2, 3)).cuda() output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask, - stage_manager=stage_manager, - stage_index=stage_index) + stage_manager=stage_manager) # print(output[0].shape) assert output[0].shape == (2, 3, 128) From 3f629371e36c901d842dcc959e9e4783f0a57766 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 11 Jul 2023 15:45:16 +0800 Subject: [PATCH 6/8] finish bloom model --- colossalai/shardformer/policies/bloom.py | 55 ++++++------ .../test_model/test_shard_bloom_pipeline.py | 84 +++++++++++++++++++ 2 files changed, 115 insertions(+), 24 deletions(-) create mode 100644 tests/test_shardformer/test_model/test_shard_bloom_pipeline.py diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 7452b611bed0..5cfc1ab29edf 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -130,11 +130,17 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() + policy = super().module_policy() from transformers.models.bloom.modeling_bloom import BloomModel if self.pipeline_stage_manager: - module_policy[BloomModel] = ModulePolicyDescription( - method_replacement={"forward": partial(bloom_model_forward, stage_manager=self.pipeline_stage_manager)}) + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + policy[BloomModel] = ModulePolicyDescription(method_replacement={ + "forward": + partial(bloom_model_forward, stage_manager=self.pipeline_stage_manager, stage_index=stage_index) + }) + return policy def get_held_layers(self) -> List[Module]: """ @@ -144,21 +150,21 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - if self.stage_manager.is_first_stage(): + if stage_manager.is_first_stage(): held_layers.append(module.word_embeddings) held_layers.append(module.word_embeddings_layernorm) - start_idx, end_idx = self.get_stage_index(layers_per_stage, self.stage_manager.stage) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) held_layers.extend(module.h[start_idx:end_idx]) - if self.stage_manager.is_last_stage(): + if stage_manager.is_last_stage(): held_layers.append(module.ln_f) return held_layers - def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]: + def get_shared_params(self) -> List[Dict[int, Tensor]]: '''no shared params in bloommodel''' - pass + return [] class BloomForCausalLMPolicy(BloomPolicy): @@ -244,6 +250,7 @@ def bloom_model_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop("position_ids", False) is not False: @@ -335,12 +342,8 @@ def bloom_model_forward( past_key_values_length=past_key_values_length, ) - # calculate the num_layers - num_layers_per_stage = len(self.h) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - - for i, (block, layer_past) in enumerate(zip(self.h[start_layer:end_layer], past_key_values[start_layer:end_layer])): + start_idx, end_idx = stage_index[0], stage_index[1] + for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx])): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -389,13 +392,17 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - # attention_mask is not returned ; presents = past_key_values - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + # attention_mask is not returned ; presents = past_key_values + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + else: + # always return dict for imediate stage + return {'hidden_states': hidden_states} diff --git a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py new file mode 100644 index 000000000000..31760d692694 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py @@ -0,0 +1,84 @@ +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + pass + + +@parameterize('enable_fused_normalization', [False]) +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('use_lazy_init', [False]) +#TODO: merge this into test_shard_bloom +def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + x = torch.randint(0, 1000, (2, 3)).cuda() + hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32).cuda() + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name == 'transformers_bloom': + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask) + assert output['hidden_states'].shape == (2, 3, 64) + else: + attention_mask = torch.ones((2, 3)).cuda() + output = sharded_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + assert output[0].shape == (2, 3, 64) + + torch.cuda.empty_cache() + + +def check_bloom(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bloom_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom(): + spawn(check_bloom, 4) + + +if __name__ == "__main__": + test_bloom() From 054d79aa2c513ea8e8d39a114b98dbee7dfc35de Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 12 Jul 2023 14:22:09 +0800 Subject: [PATCH 7/8] test shard gpt2 --- tests/kit/model_zoo/transformers/gpt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index b9e0310780af..c627a5f973df 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -53,6 +53,7 @@ def data_gen_for_sequence_classification(): config = transformers.GPT2Config(n_layer=2, n_head=4, + n_embd=128, vocab_size=50258, attn_pdrop=0, embd_pdrop=0, From af807c015f0f7b18063b06b8ae2bf165166f8dc9 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 12 Jul 2023 14:23:14 +0800 Subject: [PATCH 8/8] clear cache --- tests/kit/model_zoo/transformers/gpt.py | 21 ++++++++++--------- .../test_model/test_shard_gpt2.py | 1 + 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index c627a5f973df..f9a0888ff80d 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -51,16 +51,17 @@ def data_gen_for_sequence_classification(): loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean() loss_fn = lambda x: x.loss -config = transformers.GPT2Config(n_layer=2, - n_head=4, - n_embd=128, - vocab_size=50258, - attn_pdrop=0, - embd_pdrop=0, - resid_pdrop=0, - summary_first_dropout=0, - hidden_dropout=0, - problem_type="single_label_classification") +config = transformers.GPT2Config( + n_layer=2, + n_head=4, + #n_embd=128, + vocab_size=50258, + attn_pdrop=0, + embd_pdrop=0, + resid_pdrop=0, + summary_first_dropout=0, + hidden_dropout=0, + problem_type="single_label_classification") # register the following models model_zoo.register(name='transformers_gpt', diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 96c4b90a8075..9e5608e7fdcf 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -70,6 +70,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) @parameterize('use_lazy_init', [False, True]) +@clear_cache_before_run() def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():