From cb9d34b7da30f2326c96eafd3df02f14545d0d59 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 16:28:28 +0800 Subject: [PATCH 01/14] [shardformer] opt support flash attention --- colossalai/shardformer/layer/__init__.py | 3 +- .../layer/flash_attention/__init__.py | 3 + .../flash_attention/flash_attention_opt.py | 87 +++++++++++ colossalai/shardformer/policies/opt.py | 146 +++++++++++------- colossalai/shardformer/shard/shard_config.py | 12 +- .../test_flash_attention_for_opt.py | 38 +++++ tests/test_shardformer/test_model/_utils.py | 5 +- .../test_model/test_shard_opt.py | 39 ++--- 8 files changed, 241 insertions(+), 92 deletions(-) create mode 100644 colossalai/shardformer/layer/flash_attention/__init__.py create mode 100644 colossalai/shardformer/layer/flash_attention/flash_attention_opt.py create mode 100644 tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7fad4948dfd0..fef65fc5fa52 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,5 +1,6 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D +from .flash_attention import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm @@ -8,5 +9,5 @@ __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm' + 'FusedLayerNorm', 'FusedRMSNorm', 'opt_flash_attention_forward' ] diff --git a/colossalai/shardformer/layer/flash_attention/__init__.py b/colossalai/shardformer/layer/flash_attention/__init__.py new file mode 100644 index 000000000000..e3ee83b005c4 --- /dev/null +++ b/colossalai/shardformer/layer/flash_attention/__init__.py @@ -0,0 +1,3 @@ +from .flash_attention_opt import opt_flash_attention_forward + +__all__ = ['opt_flash_attention_forward'] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py new file mode 100644 index 000000000000..3f0445078357 --- /dev/null +++ b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py @@ -0,0 +1,87 @@ +from typing import Optional, Tuple + +import torch + +__all__ = ['opt_flash_attention_forward'] + + +def opt_flash_attention_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() + + try: + from xformers.ops import memory_efficient_attention as me_attention + except ImportError: + print("Error: xformers module is not installed. Please install it to use flash attention.") + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=attention_mask, + p=self.dropout, + scale=self.scaling) + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index b87db53f45f1..725def7f66ab 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,4 +1,10 @@ -from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedLayerNorm, + Linear1D_Col, + Linear1D_Row, + VocabParallelEmbedding1D, + opt_flash_attention_forward, +) from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -29,67 +35,90 @@ def preprocess(self): def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer - policy = {} - - if self.shard_config.enable_tensor_parallelism: - policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]) - policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]) - - policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]) + base_policy = { + OPTDecoder: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]), + OPTDecoderLayer: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]), + OPTAttention: + ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]), + } # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), - policy=policy, - target_key=OPTDecoder) - self.append_or_create_submodule_replacement(description=[ + base_policy[OPTDecoder].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="final_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True)) + base_policy[OPTDecoderLayer].sub_module_replacement.extend([ SubModuleReplacementDescription(suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True) - ], - policy=policy, - target_key=OPTDecoderLayer) + ]) - return policy + # use flash attention + if self.shard_config.enable_flash_attention: + del base_policy[OPTAttention] + new_item = { + OPTAttention: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + method_replacement={ + 'forward': opt_flash_attention_forward, + }, + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="q_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="k_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="v_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + ]), + } + base_policy.update(new_item) + + return base_policy def postprocess(self): return self.model @@ -107,12 +136,15 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() + new_item = { + OPTForCausalLM: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } - if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=OPTForCausalLM) + policy.update(new_item) return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 83c08d275df3..0e826d38ddd0 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,14 +13,14 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. - enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. + enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. enable_all_optimization (bool): Whether to turn on all optimization, default is False. """ tensor_parallel_process_group: ProcessGroup = None - enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False + enable_flash_attention: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int @@ -34,11 +34,8 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): - if not self.enable_tensor_parallelism: - self._tensor_parallel_size = 1 - else: - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) # turn on all optimization if all_optimization is set to True if self.enable_all_optimization: @@ -50,3 +47,4 @@ def _turn_on_all_optimization(self): """ # you can add all the optimization flag here self.enable_fused_normalization = True + self.enable_flash_attention = True diff --git a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py new file mode 100644 index 000000000000..7f196e86a253 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py @@ -0,0 +1,38 @@ +from copy import deepcopy + +import torch +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.opt.modeling_opt import OPTAttention + +from colossalai.shardformer.layer import opt_flash_attention_forward + + +def test_flash_attention_for_opt(): + BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 + + # generate input + hidden_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), + dtype=torch.float32, + device="cuda", + requires_grad=True) + key_value_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), + dtype=torch.float32, + device="cuda", + requires_grad=True) + attention_mask = torch.ones((BATCH, 1, N_CTX, N_CTX), dtype=torch.float32, device="cuda") + + opt_attention = OPTAttention(embed_dim=D_HEAD * N_HEADS, num_heads=N_HEADS, dropout=0, is_decoder=True, + bias=True).to("cuda") + + opt_flash_attention = deepcopy(opt_attention) + setattr(opt_flash_attention, 'forward', + opt_flash_attention_forward.__get__(opt_flash_attention, opt_flash_attention.__class__)) + opt_attention_output = opt_attention(hidden_states, key_value_states, attention_mask=attention_mask) + flash_attention_output = opt_flash_attention(hidden_states, key_value_states, attention_mask=attention_mask) + + assert_close(flash_attention_output[0], opt_attention_output[0], atol=1e-5, rtol=1e-5) + + +if __name__ == '__main__': + test_flash_attention_for_opt() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d83d9ecd39e0..7eff506e0c28 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,13 +3,12 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True): +def build_model(model_fn, enable_flash_attention=False): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism) + shard_config = ShardConfig(enable_fused_normalization=True, enable_flash_attention=enable_flash_attention) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 25bccb13b1a8..af25ba6a8c6c 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,9 +6,9 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, + check_state_dict_equal, clear_cache_before_run, parameterize, rerun_if_address_is_in_use, @@ -43,53 +43,44 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad - shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # check embedding grad org_grad = opt_model.decoder.embed_tokens.weight.grad shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad - shard_weight = shard_opt_model.decoder.embed_tokens.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +def check_OPTModel(enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() -def check_OPTModel(rank, world_size, port): +def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_t5_test() + check_OPTModel() + torch.cuda.empty_cache() @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_OPTModel(): - spawn(check_OPTModel, 4) + spawn(run_dist, 4) if __name__ == '__main__': From ac179b5c7ea8d11b77711c1eae1b916a72f656fc Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 17:13:19 +0800 Subject: [PATCH 02/14] [shardformer] opt support flash attention --- colossalai/shardformer/policies/opt.py | 161 +++++++++--------- colossalai/shardformer/shard/shard_config.py | 14 +- tests/test_shardformer/test_model/_utils.py | 9 +- .../test_model/test_shard_opt.py | 38 +++-- 4 files changed, 120 insertions(+), 102 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 725def7f66ab..5e1c1cf36ed4 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -35,90 +35,94 @@ def preprocess(self): def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer - base_policy = { - OPTDecoder: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]), - OPTDecoderLayer: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]), - OPTAttention: - ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]), - } + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]) + policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]) + + policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[OPTDecoder].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="final_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True)) - base_policy[OPTDecoderLayer].sub_module_replacement.extend([ + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription(suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True) - ]) + ], + policy=policy, + target_key=OPTDecoderLayer) # use flash attention if self.shard_config.enable_flash_attention: - del base_policy[OPTAttention] - new_item = { - OPTAttention: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - method_replacement={ - 'forward': opt_flash_attention_forward, - }, - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="q_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="k_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="v_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - ]), - } - base_policy.update(new_item) - - return base_policy + policy[OPTAttention] = ModulePolicyDescription( + method_replacement={ + 'forward': opt_flash_attention_forward, + }, + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="q_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="k_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="v_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="out_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + ]) + + return policy def postprocess(self): return self.model @@ -136,15 +140,12 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() - new_item = { - OPTForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=OPTForCausalLM) - policy.update(new_item) return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0e826d38ddd0..792e9a3f7ad4 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,11 +13,12 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. + enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. - enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. enable_all_optimization (bool): Whether to turn on all optimization, default is False. """ tensor_parallel_process_group: ProcessGroup = None + enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False enable_flash_attention: bool = False @@ -34,12 +35,11 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) - - # turn on all optimization if all_optimization is set to True - if self.enable_all_optimization: - self._turn_on_all_optimization() + if not self.enable_tensor_parallelism: + self._tensor_parallel_size = 1 + else: + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) def _turn_on_all_optimization(self): """ diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 7eff506e0c28..44a8f841a6d4 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,12 +3,17 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_flash_attention=False): +def build_model(model_fn, + enable_fused_normalization=True, + enable_tensor_parallelism=True, + enable_flash_attention=False): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(enable_fused_normalization=True, enable_flash_attention=enable_flash_attention) + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index af25ba6a8c6c..9ca0669340a5 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,6 +6,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, check_state_dict_equal, @@ -43,44 +44,55 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad + shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # check embedding grad org_grad = opt_model.decoder.embed_tokens.weight.grad shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + shard_weight = shard_opt_model.decoder.embed_tokens.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) @parameterize('enable_flash_attention', [True, False]) -def check_OPTModel(enable_flash_attention): +def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_flash_attention) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() -def run_dist(rank, world_size, port): +def check_OPTModel(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_OPTModel() - torch.cuda.empty_cache() + run_opt_test() @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_OPTModel(): - spawn(run_dist, 4) + spawn(check_OPTModel, 4) if __name__ == '__main__': From 00c4a826e2fe0fcc6e652b87fdb80e60e83fdbfe Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 17:21:08 +0800 Subject: [PATCH 03/14] [shardformer] opt support flash attention --- colossalai/shardformer/policies/opt.py | 6 ------ colossalai/shardformer/shard/shard_config.py | 3 +++ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 5e1c1cf36ed4..7a838a632776 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -79,12 +79,6 @@ def module_policy(self): ]) # optimization configuration - if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), - policy=policy, - target_key=OPTDecoder) - # optimization configuration if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 792e9a3f7ad4..4076becdbbc9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -40,6 +40,9 @@ def __post_init__(self): else: # get the parallel size self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + # turn on all optimization if all_optimization is set to True + if self.enable_all_optimization: + self._turn_on_all_optimization() def _turn_on_all_optimization(self): """ From 27526d943654b10f75304d3f08e1f02dc8134dfc Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 18:22:23 +0800 Subject: [PATCH 04/14] [shardformer] opt support flash attention --- colossalai/shardformer/layer/__init__.py | 2 +- .../layer/flash_attention/__init__.py | 3 --- .../flash_attention/flash_attention_opt.py | 5 +++-- colossalai/shardformer/policies/opt.py | 21 +++---------------- 4 files changed, 7 insertions(+), 24 deletions(-) delete mode 100644 colossalai/shardformer/layer/flash_attention/__init__.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index fef65fc5fa52..a994e8f37e91 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,6 +1,6 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D -from .flash_attention import opt_flash_attention_forward +from .flash_attention.flash_attention_opt import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm diff --git a/colossalai/shardformer/layer/flash_attention/__init__.py b/colossalai/shardformer/layer/flash_attention/__init__.py deleted file mode 100644 index e3ee83b005c4..000000000000 --- a/colossalai/shardformer/layer/flash_attention/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .flash_attention_opt import opt_flash_attention_forward - -__all__ = ['opt_flash_attention_forward'] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py index 3f0445078357..8b084fcb72ae 100644 --- a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py +++ b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py @@ -68,8 +68,9 @@ def opt_flash_attention_forward( try: from xformers.ops import memory_efficient_attention as me_attention - except ImportError: - print("Error: xformers module is not installed. Please install it to use flash attention.") + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + attn_output = me_attention(query_states, key_states, value_states, diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 7a838a632776..9ca0373c69da 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -97,24 +97,9 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[OPTAttention] = ModulePolicyDescription( - method_replacement={ - 'forward': opt_flash_attention_forward, - }, - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="q_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="k_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="v_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="out_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - ]) + policy[OPTAttention] = ModulePolicyDescription(method_replacement={ + 'forward': opt_flash_attention_forward, + }) return policy From 70535bd398f25660f60a6912856151172fa128d5 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 5 Jul 2023 11:16:46 +0800 Subject: [PATCH 05/14] [shardformer] move to modeling --- colossalai/shardformer/layer/__init__.py | 15 +++- .../flash_attention/flash_attention_opt.py | 88 ------------------- colossalai/shardformer/policies/opt.py | 9 +- .../test_flash_attention_for_opt.py | 38 -------- 4 files changed, 13 insertions(+), 137 deletions(-) delete mode 100644 colossalai/shardformer/layer/flash_attention/flash_attention_opt.py delete mode 100644 tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index a994e8f37e91..72e002497fa5 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,13 +1,20 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D -from .flash_attention.flash_attention_opt import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ - "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', - 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm', 'opt_flash_attention_forward' + "Embedding1D", + "VocabParallelEmbedding1D", + "Linear1D_Col", + "Linear1D_Row", + 'GPT2FusedLinearConv1D_Col', + 'GPT2FusedLinearConv1D_Row', + 'DropoutForParallelInput', + 'DropoutForReplicatedInput', + "cross_entropy_1d", + 'FusedLayerNorm', + 'FusedRMSNorm', ] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py deleted file mode 100644 index 8b084fcb72ae..000000000000 --- a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Optional, Tuple - -import torch - -__all__ = ['opt_flash_attention_forward'] - - -def opt_flash_attention_forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - - attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) - # get query proj - # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) - query_states = self.q_proj(hidden_states).view(*attention_input_shape) - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) - value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*attention_input_shape) - value_states = self.v_proj(key_value_states).view(*attention_input_shape) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) - key_states = torch.cat([past_key_value[0], key_states], dim=1) - value_states = torch.cat([past_key_value[1], value_states], dim=1) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - src_len = key_states.size(1) - if layer_head_mask != None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}") - if attention_mask != None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") - attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() - - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - - attn_output = me_attention(query_states, - key_states, - value_states, - attn_bias=attention_mask, - p=self.dropout, - scale=self.scaling) - - attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 9ca0373c69da..fa7bf1befd44 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,12 +1,7 @@ -from colossalai.shardformer.layer import ( - FusedLayerNorm, - Linear1D_Col, - Linear1D_Row, - VocabParallelEmbedding1D, - opt_flash_attention_forward, -) +from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .._utils import getattr_, setattr_ +from ..modeling.opt import opt_flash_attention_forward from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ diff --git a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py deleted file mode 100644 index 7f196e86a253..000000000000 --- a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py +++ /dev/null @@ -1,38 +0,0 @@ -from copy import deepcopy - -import torch -import torch.nn as nn -from torch.testing import assert_close -from transformers.models.opt.modeling_opt import OPTAttention - -from colossalai.shardformer.layer import opt_flash_attention_forward - - -def test_flash_attention_for_opt(): - BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 - - # generate input - hidden_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), - dtype=torch.float32, - device="cuda", - requires_grad=True) - key_value_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), - dtype=torch.float32, - device="cuda", - requires_grad=True) - attention_mask = torch.ones((BATCH, 1, N_CTX, N_CTX), dtype=torch.float32, device="cuda") - - opt_attention = OPTAttention(embed_dim=D_HEAD * N_HEADS, num_heads=N_HEADS, dropout=0, is_decoder=True, - bias=True).to("cuda") - - opt_flash_attention = deepcopy(opt_attention) - setattr(opt_flash_attention, 'forward', - opt_flash_attention_forward.__get__(opt_flash_attention, opt_flash_attention.__class__)) - opt_attention_output = opt_attention(hidden_states, key_value_states, attention_mask=attention_mask) - flash_attention_output = opt_flash_attention(hidden_states, key_value_states, attention_mask=attention_mask) - - assert_close(flash_attention_output[0], opt_attention_output[0], atol=1e-5, rtol=1e-5) - - -if __name__ == '__main__': - test_flash_attention_for_opt() From 794dd86a6e63e97df35cad6c746c08caace758e6 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 5 Jul 2023 11:19:51 +0800 Subject: [PATCH 06/14] [shardformer] move to modeling --- colossalai/shardformer/modeling/opt.py | 88 ++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 colossalai/shardformer/modeling/opt.py diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py new file mode 100644 index 000000000000..8b084fcb72ae --- /dev/null +++ b/colossalai/shardformer/modeling/opt.py @@ -0,0 +1,88 @@ +from typing import Optional, Tuple + +import torch + +__all__ = ['opt_flash_attention_forward'] + + +def opt_flash_attention_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=attention_mask, + p=self.dropout, + scale=self.scaling) + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value From a33f6ce306ec102be6daea1bba1811348a269908 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 11 Jul 2023 19:48:45 +0800 Subject: [PATCH 07/14] [shardformer] t5 support flash attention --- .../engine/schedule/_pipeline_schedule.py | 9 +- colossalai/shardformer/modeling/t5.py | 125 ++++++++++++++++++ colossalai/shardformer/policies/t5.py | 7 + tests/kit/model_zoo/transformers/t5.py | 6 +- .../test_model/test_shard_t5.py | 5 +- 5 files changed, 141 insertions(+), 11 deletions(-) create mode 100644 colossalai/shardformer/modeling/t5.py diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index 9fc301a26559..25fe24023806 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -605,12 +605,10 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo num_warmup_microbatches = num_microbatches all_warmup_microbatches = True else: - num_warmup_microbatches = \ - (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 + num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_microbatches_remaining = \ - num_microbatches - num_warmup_microbatches + num_microbatches_remaining = num_microbatches - num_warmup_microbatches def get_model_chunk_id(microbatch_id, forward): """Helper method to get the model chunk ID given the iteration number.""" @@ -629,8 +627,7 @@ def _forward_step_helper(microbatch_id): # forward step if gpc.is_pipeline_first_stage(): - if len(input_objs[model_chunk_id]) == \ - len(output_objs[model_chunk_id]): + if len(input_objs[model_chunk_id]) == len(output_objs[model_chunk_id]): input_objs[model_chunk_id].append(None) input_obj = input_objs[model_chunk_id][-1] output_obj = self._forward_step(engine, diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py new file mode 100644 index 000000000000..b3a151adf9a2 --- /dev/null +++ b/colossalai/shardformer/modeling/t5.py @@ -0,0 +1,125 @@ +from typing import Optional, Tuple + +import torch + +__all__ = ['get_t5_forward'] +def get_t5_forward(): + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + + def t5_flash_attention_forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + if len(past_key_value) != 2: + raise ValueError( + f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + ) + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim) + + def unshape(states): + """reshape""" + return states.view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=1) + elif past_key_value.shape[1] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=query_states.device, dtype=query_states.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + position_bias_masked = position_bias_masked.contiguous() + attn_output = me_attention(query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0) + attn_output = unshape(attn_output) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + return outputs + + return t5_flash_attention_forward \ No newline at end of file diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index cde59ab77042..5463878b04b3 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -10,6 +10,7 @@ from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from ..modeling.t5 import get_t5_forward __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] @@ -161,6 +162,12 @@ def module_policy(self): suffix="final_layer_norm", target_module=FusedRMSNorm), policy=policy, target_key=T5Stack) + + # use flash attention + if self.shard_config.enable_flash_attention: + policy[T5Attention] = ModulePolicyDescription(method_replacement={ + 'forward': get_t5_forward(), + }) return policy def postprocess(self): diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index 689db2c40abb..d9d41fb0f451 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -16,7 +16,7 @@ def data_gen_for_encoder_only(): # config = T5Config(decoder_start_token_id=0) # tokenizer = T5Tokenizer.from_pretrained("t5-small") # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids - input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long() + input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long() return dict(input_ids=input_ids) @@ -25,7 +25,7 @@ def data_gen_for_conditional_generation(): # # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids data = data_gen_for_encoder_only() - labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long() + labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long() data['labels'] = labels return data @@ -35,7 +35,7 @@ def data_gen_for_t5_model(): # # decoder_input_ids = model._shift_right(input_ids) data = data_gen_for_encoder_only() - decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long() + decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long() data['decoder_input_ids'] = decoder_input_ids return data diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 0762dc09e5af..ff4b71fe76a4 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -82,10 +82,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From c48544b5edecc1733e0e2c4d5f3de8a8ec44912f Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 11 Jul 2023 19:50:04 +0800 Subject: [PATCH 08/14] [shardformer] t5 support flash attention --- colossalai/engine/schedule/_pipeline_schedule.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index 25fe24023806..d29eaf170ccb 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -605,10 +605,12 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo num_warmup_microbatches = num_microbatches all_warmup_microbatches = True else: - num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 + num_warmup_microbatches = \ + (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_microbatches_remaining = num_microbatches - num_warmup_microbatches + num_microbatches_remaining = \ + num_microbatches - num_warmup_microbatches def get_model_chunk_id(microbatch_id, forward): """Helper method to get the model chunk ID given the iteration number.""" @@ -627,7 +629,8 @@ def _forward_step_helper(microbatch_id): # forward step if gpc.is_pipeline_first_stage(): - if len(input_objs[model_chunk_id]) == len(output_objs[model_chunk_id]): + if len(input_objs[model_chunk_id]) == \ + len(output_objs[model_chunk_id]): input_objs[model_chunk_id].append(None) input_obj = input_objs[model_chunk_id][-1] output_obj = self._forward_step(engine, @@ -827,4 +830,4 @@ def _backward_step_helper(microbatch_id): output, label = pack_return_tensors(return_tensors) return output, label, accum_loss else: - return None, None, accum_loss + return None, None, accum_loss \ No newline at end of file From 1f4d29802589de0c7b97577f7141624abb6e6bf7 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 12 Jul 2023 09:36:50 +0800 Subject: [PATCH 09/14] fix typo --- colossalai/shardformer/modeling/t5.py | 46 +++++++++++++++----------- tests/kit/model_zoo/transformers/t5.py | 27 +++++++-------- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index b3a151adf9a2..4c84ad29b931 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -3,12 +3,14 @@ import torch __all__ = ['get_t5_forward'] + + def get_t5_forward(): try: from xformers.ops import memory_efficient_attention as me_attention except: raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - + def t5_flash_attention_forward( self, hidden_states, @@ -27,6 +29,9 @@ def t5_flash_attention_forward( # Input is (batch_size, seq_length, dim) # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + import pathlib + pathlib.Path("/home/lcjmy/code/personal/ColossalAI/colossalai/shardformer/modeling/mask.txt").write_text( + str(mask) + str(mask.shape)) batch_size, seq_length = hidden_states.shape[:2] real_seq_length = seq_length @@ -43,11 +48,11 @@ def t5_flash_attention_forward( def shape(states): """projection""" return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim) - + def unshape(states): """reshape""" return states.view(batch_size, -1, self.inner_dim) - + def project(hidden_states, proj_layer, key_value_states, past_key_value): """projects hidden states correctly to key/query states""" if key_value_states is None: @@ -74,23 +79,21 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # cross-attn hidden_states = past_key_value return hidden_states - + # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + key_states = project(hidden_states, self.k, key_value_states, + past_key_value[0] if past_key_value is not None else None) + value_states = project(hidden_states, self.v, key_value_states, + past_key_value[1] if past_key_value is not None else None) if position_bias is None: if not self.has_relative_attention_bias: - position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=query_states.device, dtype=query_states.dtype - ) + position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), + device=query_states.device, + dtype=query_states.dtype) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: @@ -99,10 +102,10 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # if key and values are already calculated # we want only the last query position bias if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = position_bias[:, :, -hidden_states.size(1):, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -112,7 +115,12 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias position_bias_masked = position_bias_masked.contiguous() - attn_output = me_attention(query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0) + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=position_bias_masked, + p=self.dropout, + scale=1.0) attn_output = unshape(attn_output) attn_output = self.o(attn_output) @@ -121,5 +129,5 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) return outputs - - return t5_flash_attention_forward \ No newline at end of file + + return t5_flash_attention_forward diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index d9d41fb0f451..78e89bcd49a9 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -17,7 +17,8 @@ def data_gen_for_encoder_only(): # tokenizer = T5Tokenizer.from_pretrained("t5-small") # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long() - return dict(input_ids=input_ids) + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long() + return dict(input_ids=input_ids, attention_mask=attention_mask) def data_gen_for_conditional_generation(): @@ -61,15 +62,15 @@ def data_gen_for_t5_model(): output_transform_fn=output_transform_fn, loss_fn=loss_fn_for_t5_model, model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_t5_for_conditional_generation', - model_fn=lambda: transformers.T5ForConditionalGeneration(config), - data_gen_fn=data_gen_for_conditional_generation, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_conditional_generation, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_t5_encoder_model', - model_fn=lambda: transformers.T5EncoderModel(config), - data_gen_fn=data_gen_for_encoder_only, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_encoder_only, - model_attribute=ModelAttribute(has_control_flow=True)) +# model_zoo.register(name='transformers_t5_for_conditional_generation', +# model_fn=lambda: transformers.T5ForConditionalGeneration(config), +# data_gen_fn=data_gen_for_conditional_generation, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn_for_conditional_generation, +# model_attribute=ModelAttribute(has_control_flow=True)) +# model_zoo.register(name='transformers_t5_encoder_model', +# model_fn=lambda: transformers.T5EncoderModel(config), +# data_gen_fn=data_gen_for_encoder_only, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn_for_encoder_only, +# model_attribute=ModelAttribute(has_control_flow=True)) From 671674ff269e32f81c91e267d4d90d496f9b9909 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 12 Jul 2023 09:37:56 +0800 Subject: [PATCH 10/14] fix typo --- colossalai/shardformer/modeling/t5.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 4c84ad29b931..130178cd7ba7 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -6,6 +6,7 @@ def get_t5_forward(): + try: from xformers.ops import memory_efficient_attention as me_attention except: From 9d75362557be389a417f3ceeb75b7b718110bd90 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 12 Jul 2023 10:46:06 +0800 Subject: [PATCH 11/14] fix typo --- colossalai/shardformer/modeling/t5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 130178cd7ba7..81fb6d28ad05 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional import torch From 819981ca81e936037b98009daf7db45cf4dc1044 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 12 Jul 2023 10:47:26 +0800 Subject: [PATCH 12/14] fix typo --- tests/kit/model_zoo/transformers/t5.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index 78e89bcd49a9..ec5ee05e8f91 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -62,15 +62,15 @@ def data_gen_for_t5_model(): output_transform_fn=output_transform_fn, loss_fn=loss_fn_for_t5_model, model_attribute=ModelAttribute(has_control_flow=True)) -# model_zoo.register(name='transformers_t5_for_conditional_generation', -# model_fn=lambda: transformers.T5ForConditionalGeneration(config), -# data_gen_fn=data_gen_for_conditional_generation, -# output_transform_fn=output_transform_fn, -# loss_fn=loss_fn_for_conditional_generation, -# model_attribute=ModelAttribute(has_control_flow=True)) -# model_zoo.register(name='transformers_t5_encoder_model', -# model_fn=lambda: transformers.T5EncoderModel(config), -# data_gen_fn=data_gen_for_encoder_only, -# output_transform_fn=output_transform_fn, -# loss_fn=loss_fn_for_encoder_only, -# model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_t5_for_conditional_generation', + model_fn=lambda: transformers.T5ForConditionalGeneration(config), + data_gen_fn=data_gen_for_conditional_generation, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_conditional_generation, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_t5_encoder_model', + model_fn=lambda: transformers.T5EncoderModel(config), + data_gen_fn=data_gen_for_encoder_only, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_encoder_only, + model_attribute=ModelAttribute(has_control_flow=True)) From 8ee4aaeeed75d2c8493ae6510c6e258605f8c912 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 12 Jul 2023 10:51:20 +0800 Subject: [PATCH 13/14] fix typo --- tests/kit/model_zoo/transformers/t5.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index ec5ee05e8f91..435cb6f46937 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -33,7 +33,6 @@ def data_gen_for_conditional_generation(): def data_gen_for_t5_model(): # decoder_inputs_ids is obtained with the following code - # # decoder_input_ids = model._shift_right(input_ids) data = data_gen_for_encoder_only() decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long() From 217a8ffd2c32b1d4e6c6dfae1f7ab3365c26dc40 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 12 Jul 2023 11:16:01 +0800 Subject: [PATCH 14/14] fix typo --- colossalai/engine/schedule/_pipeline_schedule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index d29eaf170ccb..9fc301a26559 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -830,4 +830,4 @@ def _backward_step_helper(microbatch_id): output, label = pack_return_tensors(return_tensors) return output, label, accum_loss else: - return None, None, accum_loss \ No newline at end of file + return None, None, accum_loss