From ba936c1af991b9209f4170fd905cfedec0ae361d Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 11 Oct 2023 10:03:57 +0800 Subject: [PATCH 1/4] [shardformer] shardformer support falcon [shardformer] shardformer support falcon [shardformer] shardformer support falcon [shardformer] shardformer support falcon --- colossalai/shardformer/modeling/falcon.py | 137 ++++++++++ .../shardformer/policies/auto_policy.py | 16 ++ colossalai/shardformer/policies/falcon.py | 235 ++++++++++++++++++ tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/falcon.py | 116 +++++++++ .../test_model/test_shard_falcon.py | 131 ++++++++++ 6 files changed, 636 insertions(+) create mode 100644 colossalai/shardformer/modeling/falcon.py create mode 100644 colossalai/shardformer/policies/falcon.py create mode 100644 tests/kit/model_zoo/transformers/falcon.py create mode 100644 tests/test_shardformer/test_model/test_shard_falcon.py diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py new file mode 100644 index 000000000000..bf0f63064991 --- /dev/null +++ b/colossalai/shardformer/modeling/falcon.py @@ -0,0 +1,137 @@ + +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import functional as F + +def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: + def build_falcon_alibi_tensor( + self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype + ) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. + + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). + num_heads (`int`, *required*): + number of heads + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype of the output tensor + """ + import math + + if dist.is_initialized(): + world_size = dist.get_world_size(process_group) + num_heads = num_heads * world_size + + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 + ) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange( + 1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + if dist.is_initialized(): + num_heads_per_rank = int(num_heads / dist.get_world_size(process_group)) + offset = dist.get_rank(process_group) * num_heads_per_rank + alibi = alibi.view(batch_size, num_heads, 1, seq_length) + alibi = alibi[:, offset : num_heads_per_rank + offset, :, :] + return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) + else: + return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) + + return build_falcon_alibi_tensor + + +def get_tp_falcon_decoder_layer_forward(): + from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, dropout_add + + def forward( + self: FalconDecoderLayer, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + residual = hidden_states + + if self.config.new_decoder_architecture: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + + # Self attention. + attn_outputs = self.self_attention( + attention_layernorm_out, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + attention_output = attn_outputs[0] + + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual = dropout_add( + attention_output, residual, self.config.attention_dropout, training=self.training + ) + mlp_layernorm_out = self.post_attention_layernorm(residual) + + outputs = attn_outputs[1:] + + # MLP. + mlp_output = self.mlp(mlp_layernorm_out) + + if self.config.new_decoder_architecture or self.config.parallel_attn: + mlp_output = mlp_output + attention_output + + output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + return forward \ No newline at end of file diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index f3587de15f86..6e97c6a218e5 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -147,6 +147,22 @@ class PolicyLocation: "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" ), + # Falcon + "transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation( + file_name="falcon", class_name="FalconModelPolicy" + ), + "transformers.models.falcon.modeling_falcon.FalconForCausalLM": PolicyLocation( + file_name="falcon", class_name="FalconForCausalLMPolicy" + ), + "transformers.models.falcon.modeling_falcon.FalconForSequenceClassification": PolicyLocation( + file_name="falcon", class_name="FalconForSequenceClassificationPolicy" + ), + "transformers.models.falcon.modeling_falcon.FalconForTokenClassification": PolicyLocation( + file_name="falcon", class_name="FalconForTokenClassificationPolicy" + ), + "transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation( + file_name="falcon", class_name="FalconForQuestionAnsweringPolicy" + ), } _INFER_POLICY_LIST = { diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py new file mode 100644 index 000000000000..d5b4c4203ea5 --- /dev/null +++ b/colossalai/shardformer/policies/falcon.py @@ -0,0 +1,235 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List + +from torch import Tensor, nn + +import colossalai.shardformer.layer as col_nn + +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from ..modeling.falcon import build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward +__all__ = [ + "FalconPolicy" +] + +class FalconPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.falcon.modeling_falcon import FalconModel, FalconDecoderLayer + + if not self.model.config.new_decoder_architecture and self.model.config.multi_query: + warnings.warn("Falcon dosen't support tensor parallelism when (not new_decoder_architecture and multi_query) is True, will ignore the tensor parallelism flag.") + self.shard_config.enable_tensor_parallelism = False + + policy = {} + if self.shard_config.enable_tensor_parallelism: + attn_attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.num_kv_heads": self.model.config.num_kv_heads // self.shard_config.tensor_parallel_size, + } + + policy[FalconDecoderLayer] = ModulePolicyDescription( + attribute_replacement=attn_attribute_replacement, + method_replacement={ + "forward": get_tp_falcon_decoder_layer_forward() + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row + ), + ] + ) + + + policy[FalconModel] = ModulePolicyDescription( + attribute_replacement={ + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + method_replacement={ + "build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ], + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # handle falcon model + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=FalconModel, + ) + + # handle falcon decoder layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_attn", + target_module=col_nn.FusedLayerNorm, + ignore_if_not_exist=True + ), + SubModuleReplacementDescription( + suffix="ln_mlp", + target_module=col_nn.FusedLayerNorm, + ignore_if_not_exist=True + ), + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=col_nn.FusedLayerNorm, + ignore_if_not_exist=True + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=col_nn.FusedLayerNorm, + ignore_if_not_exist=True + ), + ], + policy=policy, + target_key=FalconDecoderLayer, + ) + + return policy + + def postprocess(self): + return self.model + +class FalconModelPolicy(FalconPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + return policy + +class FalconForCausalLMPolicy(FalconPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.falcon.modeling_falcon import FalconForCausalLM + + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=FalconForCausalLM, + ) + return policy + +class FalconForSequenceClassificationPolicy(FalconPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.falcon.modeling_falcon import FalconForSequenceClassification + + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=FalconForSequenceClassification, + ) + return policy + +class FalconForTokenClassificationPolicy(FalconPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.falcon.modeling_falcon import FalconForTokenClassification + + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ) + ], + policy=policy, + target_key=FalconForTokenClassification, + ) + return policy + +class FalconForQuestionAnsweringPolicy(FalconPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.falcon.modeling_falcon import FalconForQuestionAnswering + + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="qa_outputs", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=FalconForQuestionAnswering, + ) + return policy \ No newline at end of file diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 2a492361b13b..4556acbe97c6 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -10,3 +10,4 @@ from .t5 import * from .vit import * from .whisper import * +from .falcon import * diff --git a/tests/kit/model_zoo/transformers/falcon.py b/tests/kit/model_zoo/transformers/falcon.py new file mode 100644 index 000000000000..9f1c515a378a --- /dev/null +++ b/tests/kit/model_zoo/transformers/falcon.py @@ -0,0 +1,116 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register Falcon +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoTokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data + + +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data["labels"] = torch.tensor([0], dtype=torch.int64) + return data + + +def data_gen_for_question_answering(): + + input_ids = torch.tensor( + [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], + dtype=torch.int64, + ) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + start_positions = torch.tensor([1], dtype=torch.int64) + end_positions = torch.tensor([10], dtype=torch.int64) + return dict( + input_ids=input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions + ) + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_falcon_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) +loss_fn_for_causal_lm = lambda x: x.loss +loss_fn_for_classification = lambda x: x.loss +loss_fn_for_question_answering = lambda x: x.loss + +config = transformers.FalconConfig( + num_hidden_layers=2, num_attention_heads=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, multi_query=False, + new_decoder_architecture=True +) + +model_zoo.register( + name="transformers_falcon", + model_fn=lambda: transformers.FalconModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_falcon_model, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_falcon_for_causal_lm", + model_fn=lambda: transformers.FalconForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_causal_lm, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_falcon_for_sequence_classification", + model_fn=lambda: transformers.FalconForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_falcon_for_token_classification", + model_fn=lambda: transformers.FalconForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_falcon_for_question_answering", + model_fn=lambda: transformers.FalconForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_question_answering, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/test_shardformer/test_model/test_shard_falcon.py b/tests/test_shardformer/test_model/test_shard_falcon.py new file mode 100644 index 000000000000..323a117bc84c --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_falcon.py @@ -0,0 +1,131 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + falcon = unwrap_model(org_model, "FalconModel", "transformer") + sharded_falcon = unwrap_model(sharded_model, "FalconModel", "transformer") + + row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"] + col_layer_for_check = ["h[0].self_attention.dense"] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-6, 1e-5 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + falcon, sharded_falcon, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + falcon, sharded_falcon, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if org_model.__class__.__name__ == "FalconModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 2e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight(falcon, sharded_falcon, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 1, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 1, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) +def run_falcon_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_falcon") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_falcon(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_falcon_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_falcon(): + spawn(check_falcon, 4) + +if __name__ == "__main__": + test_falcon() From cd8385d7429d406ce3f2fa5887f814f727a17df2 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 13 Oct 2023 10:23:02 +0800 Subject: [PATCH 2/4] [shardformer] falcon support flash attention [shardformer] falcon support flash attention --- colossalai/shardformer/modeling/falcon.py | 87 ++++++++++++++++++- colossalai/shardformer/policies/falcon.py | 17 +++- tests/kit/model_zoo/transformers/falcon.py | 6 +- .../test_model/test_shard_falcon.py | 2 +- 4 files changed, 105 insertions(+), 7 deletions(-) diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index bf0f63064991..fcfac9a9b021 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -134,4 +134,89 @@ def forward( return outputs # hidden_states, present, attentions - return forward \ No newline at end of file + return forward + +def get_falcon_flash_attention_forward(): + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + from transformers.models.falcon.modeling_falcon import FalconAttention + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + def forward( + self: FalconAttention, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape( + batch_size * num_kv_heads, + query_length, + self.head_dim, + ) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim) + + + past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] + query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length) + + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, kv_length, head_dim] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, kv_length, _ = key_layer.shape + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype) + + query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous() + key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous() + value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous() + + if alibi is not None: + attention_mask_float = ( + attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta + ) + + batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1] + tgt_len = key_layer_.size()[1] + attention_mask_float = attention_mask_float.expand( + batch_size, self.num_heads, src_len, tgt_len + ).contiguous() + context_layer = me_attention( + query_layer_, + key_layer_, + value_layer_, + attn_bias=attention_mask_float, + scale=self.inv_norm_factor, + p=self.attention_dropout.p, + ) + batch_size, seq_length, _, _ = context_layer.shape + context_layer = context_layer.reshape(batch_size, seq_length, -1) + + output_tensor = self.dense(context_layer) + + return output_tensor, present + + return forward diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index d5b4c4203ea5..432d71b6c975 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -7,7 +7,11 @@ import colossalai.shardformer.layer as col_nn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -from ..modeling.falcon import build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward +from ..modeling.falcon import ( + build_falcon_alibi_tensor_fn, + get_tp_falcon_decoder_layer_forward, + get_falcon_flash_attention_forward +) __all__ = [ "FalconPolicy" ] @@ -30,7 +34,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.falcon.modeling_falcon import FalconModel, FalconDecoderLayer + from transformers.models.falcon.modeling_falcon import FalconModel, FalconDecoderLayer, FalconAttention if not self.model.config.new_decoder_architecture and self.model.config.multi_query: warnings.warn("Falcon dosen't support tensor parallelism when (not new_decoder_architecture and multi_query) is True, will ignore the tensor parallelism flag.") @@ -135,6 +139,15 @@ def module_policy(self): target_key=FalconDecoderLayer, ) + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_falcon_flash_attention_forward() + }, + policy=policy, + target_key=FalconAttention, + ) + print(policy) return policy def postprocess(self): diff --git a/tests/kit/model_zoo/transformers/falcon.py b/tests/kit/model_zoo/transformers/falcon.py index 9f1c515a378a..d1909ed841b8 100644 --- a/tests/kit/model_zoo/transformers/falcon.py +++ b/tests/kit/model_zoo/transformers/falcon.py @@ -16,8 +16,8 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt') # input_ids = tokenized_input['input_ids'] # attention_mask = tokenized_input['attention_mask'] - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -33,7 +33,7 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) + data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) return data def data_gen_for_sequence_classification(): diff --git a/tests/test_shardformer/test_model/test_shard_falcon.py b/tests/test_shardformer/test_model/test_shard_falcon.py index 323a117bc84c..dfde1f2cc2ab 100644 --- a/tests/test_shardformer/test_model/test_shard_falcon.py +++ b/tests/test_shardformer/test_model/test_shard_falcon.py @@ -88,7 +88,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 1, "num_microbatches": 4, "enable_all_optimization": True, - "use_lazy_init": True, + "use_lazy_init": False, "precision": "fp32", "initial_scale": 1, }, From 52979c0cb1db92b45ecab08bafe61e12315122fb Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 13 Oct 2023 16:05:59 +0800 Subject: [PATCH 3/4] falcon support pipeline parallel falcon support pipeline parallel --- colossalai/shardformer/modeling/falcon.py | 552 ++++++++++++++++++ colossalai/shardformer/policies/falcon.py | 145 ++++- tests/kit/model_zoo/transformers/falcon.py | 11 +- .../test_model/test_shard_falcon.py | 85 ++- 4 files changed, 786 insertions(+), 7 deletions(-) diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index fcfac9a9b021..494c1921a2d3 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -4,10 +4,32 @@ import torch import torch.distributed as dist +from colossalai.pipeline.stage_manager import PipelineStageManager from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F +from transformers.utils import logging + +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) + +from transformers.models.falcon.modeling_falcon import ( + FalconForCausalLM, + FalconForQuestionAnswering, + FalconForSequenceClassification, + FalconForTokenClassification, + FalconModel, +) +from transformers.models.falcon.modeling_falcon import build_alibi_tensor + +from colossalai.shardformer.shard import ShardConfig + def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: def build_falcon_alibi_tensor( self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype @@ -220,3 +242,533 @@ def forward( return output_tensor, present return forward + + +class FalconPipelineForwards: + """ + This class serves as a micro library for falcon pipeline forwards. + """ + + @staticmethod + def falcon_model_forward( + self: FalconModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + else: + past_key_values = self._convert_to_rw_cache(past_key_values) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # case: First stage of training + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = inputs_embeds + + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # Compute alibi tensor: check build_alibi_tensor documentation + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + if self.use_alibi: + alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + else: + alibi = None + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + start_idx, end_idx = stage_index[0], stage_index[1] + for i, (block, layer_past) in enumerate( + zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx + ): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + if stage_manager.is_last_stage(): + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if presents is not None: + presents = self._convert_cache_to_standard_format(presents, batch_size) + + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + else: + # always return dict for imediate stage + return {"hidden_states": hidden_states} + + + @staticmethod + def falcon_for_causal_lm_forward( + self: FalconForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + transformer_outputs = FalconPipelineForwards.falcon_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + past_key_values = None + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + + @staticmethod + def falcon_for_sequence_classification_forward( + self: FalconForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + transformer_outputs = FalconPipelineForwards.falcon_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + past_key_values = None + if stage_manager.is_last_stage(): + batch_size = hidden_states.shape[0] + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + + @staticmethod + def falcon_for_token_classification_forward( + self: FalconForTokenClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + transformer_outputs = FalconPipelineForwards.falcon_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + batch_size, seq_length = labels.shape + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def falcon_for_question_answering_forward( + self: FalconForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + outputs = FalconPipelineForwards.falcon_model_forward( + self.transformer, + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} \ No newline at end of file diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 432d71b6c975..387c01b670ea 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -3,11 +3,13 @@ from typing import Callable, Dict, List from torch import Tensor, nn +from torch.nn import Module import colossalai.shardformer.layer as col_nn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from ..modeling.falcon import ( + FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward, get_falcon_flash_attention_forward @@ -147,19 +149,77 @@ def module_policy(self): policy=policy, target_key=FalconAttention, ) - print(policy) return policy def postprocess(self): return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "FalconModel": + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + if self.model.__class__.__name__ == "FalconModel": + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.word_embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + + return held_layers + + class FalconModelPolicy(FalconPolicy): def __init__(self) -> None: super().__init__() def module_policy(self): policy = super().module_policy() + + from transformers.models.falcon.modeling_falcon import FalconModel + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=FalconModel, new_forward=FalconPipelineForwards.falcon_model_forward, policy=policy + ) return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """no shared params in falcon model""" + return [] class FalconForCausalLMPolicy(FalconPolicy): def __init__(self) -> None: @@ -179,8 +239,33 @@ def module_policy(self): policy=policy, target_key=FalconForCausalLM, ) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=FalconForCausalLM, new_forward=FalconPipelineForwards.falcon_for_causal_lm_forward, policy=policy + ) return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + falcon_model = self.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if id(falcon_model.transformer.word_embeddings.weight) == id(falcon_model.lm_head.weight): + # tie weights + return [ + { + 0: falcon_model.transformer.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: falcon_model.lm_head.weight, + } + ] + return [] + class FalconForSequenceClassificationPolicy(FalconPolicy): def __init__(self) -> None: super().__init__() @@ -199,8 +284,27 @@ def module_policy(self): policy=policy, target_key=FalconForSequenceClassification, ) + + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=FalconForSequenceClassification, + new_forward=FalconPipelineForwards.falcon_for_sequence_classification_forward, + policy=policy, + ) return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in falcon for sequence classification model""" + return [] + class FalconForTokenClassificationPolicy(FalconPolicy): def __init__(self) -> None: super().__init__() @@ -225,8 +329,27 @@ def module_policy(self): policy=policy, target_key=FalconForTokenClassification, ) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=FalconForTokenClassification, + new_forward=FalconPipelineForwards.falcon_for_token_classification_forward, + policy=policy, + ) return policy + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in falcon for token classification model""" + return [] + class FalconForQuestionAnsweringPolicy(FalconPolicy): def __init__(self) -> None: super().__init__() @@ -245,4 +368,22 @@ def module_policy(self): policy=policy, target_key=FalconForQuestionAnswering, ) - return policy \ No newline at end of file + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=FalconForQuestionAnswering, + new_forward=FalconPipelineForwards.falcon_for_question_answering_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in falcon for question answering model""" + return [] \ No newline at end of file diff --git a/tests/kit/model_zoo/transformers/falcon.py b/tests/kit/model_zoo/transformers/falcon.py index d1909ed841b8..64e4ffe07004 100644 --- a/tests/kit/model_zoo/transformers/falcon.py +++ b/tests/kit/model_zoo/transformers/falcon.py @@ -68,8 +68,15 @@ def data_gen_for_question_answering(): loss_fn_for_question_answering = lambda x: x.loss config = transformers.FalconConfig( - num_hidden_layers=2, num_attention_heads=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, multi_query=False, - new_decoder_architecture=True + num_hidden_layers=2, + num_attention_heads=4, + vocab_size=250880, + hidden_dropout=0, + attention_dropout=0, + hidden_size=64, + multi_query=False, + new_decoder_architecture=True, + pad_token_id = -1 ) model_zoo.register( diff --git a/tests/test_shardformer/test_model/test_shard_falcon.py b/tests/test_shardformer/test_model/test_shard_falcon.py index dfde1f2cc2ab..39a0aa72a61e 100644 --- a/tests/test_shardformer/test_model/test_shard_falcon.py +++ b/tests/test_shardformer/test_model/test_shard_falcon.py @@ -85,21 +85,52 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, [ { "tp_size": 2, - "pp_size": 1, + "pp_size": 2, "num_microbatches": 4, "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", - "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32" + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32" }, { "tp_size": 2, "pp_size": 1, - "num_microbatches": 4, "enable_all_optimization": True, "use_lazy_init": True, + "zero_stage": 2, "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, "zero_stage": 1, + "precision": "fp16", "initial_scale": 1, }, ], @@ -114,12 +145,52 @@ def run_falcon_test(test_config): Randomizer.reset_index() torch.cuda.empty_cache() +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) + +def run_falcon_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_falcon") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + def check_falcon(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_falcon_test() +def check_falcon_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_falcon_3d_test() + @pytest.mark.dist @rerun_if_address_is_in_use() @@ -127,5 +198,13 @@ def check_falcon(rank, world_size, port): def test_falcon(): spawn(check_falcon, 4) + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_falcon_3d(): + spawn(check_falcon_3d, 8) + if __name__ == "__main__": test_falcon() + test_falcon_3d() From 080e16e5aea82e7779b82b69739a602b72600756 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 13 Oct 2023 16:40:27 +0800 Subject: [PATCH 4/4] fix fix fix --- colossalai/shardformer/README.md | 1 + colossalai/shardformer/modeling/falcon.py | 9 +++++++++ docs/source/en/features/shardformer.md | 12 ++++++++++++ docs/source/zh-Hans/features/shardformer.md | 12 ++++++++++++ tests/test_booster/test_plugin/test_gemini_plugin.py | 5 +++++ 5 files changed, 39 insertions(+) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 4bd7d5208a64..94079dbf1e0d 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -127,6 +127,7 @@ We will follow this roadmap to develop Shardformer: | whisper | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] | | sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | | blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| falcon | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] | | roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | | albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | | ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index 494c1921a2d3..a72391372d54 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -271,7 +271,16 @@ def falcon_model_forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + logger.warning_once("past_key_values is not supported for pipeline models at the moment.") + past_key_values = None + return_dict = return_dict if return_dict is not None else self.config.use_return_dict if past_key_values is None: diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index a6e32d2c05fa..8f066ab50d65 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -178,6 +178,18 @@ Model/Feature Compatibility Matrix: