From cb9d34b7da30f2326c96eafd3df02f14545d0d59 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 16:28:28 +0800 Subject: [PATCH 01/16] [shardformer] opt support flash attention --- colossalai/shardformer/layer/__init__.py | 3 +- .../layer/flash_attention/__init__.py | 3 + .../flash_attention/flash_attention_opt.py | 87 +++++++++++ colossalai/shardformer/policies/opt.py | 146 +++++++++++------- colossalai/shardformer/shard/shard_config.py | 12 +- .../test_flash_attention_for_opt.py | 38 +++++ tests/test_shardformer/test_model/_utils.py | 5 +- .../test_model/test_shard_opt.py | 39 ++--- 8 files changed, 241 insertions(+), 92 deletions(-) create mode 100644 colossalai/shardformer/layer/flash_attention/__init__.py create mode 100644 colossalai/shardformer/layer/flash_attention/flash_attention_opt.py create mode 100644 tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7fad4948dfd0..fef65fc5fa52 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,5 +1,6 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D +from .flash_attention import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm @@ -8,5 +9,5 @@ __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm' + 'FusedLayerNorm', 'FusedRMSNorm', 'opt_flash_attention_forward' ] diff --git a/colossalai/shardformer/layer/flash_attention/__init__.py b/colossalai/shardformer/layer/flash_attention/__init__.py new file mode 100644 index 000000000000..e3ee83b005c4 --- /dev/null +++ b/colossalai/shardformer/layer/flash_attention/__init__.py @@ -0,0 +1,3 @@ +from .flash_attention_opt import opt_flash_attention_forward + +__all__ = ['opt_flash_attention_forward'] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py new file mode 100644 index 000000000000..3f0445078357 --- /dev/null +++ b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py @@ -0,0 +1,87 @@ +from typing import Optional, Tuple + +import torch + +__all__ = ['opt_flash_attention_forward'] + + +def opt_flash_attention_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if 
is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() + + try: + from xformers.ops import memory_efficient_attention as me_attention + except ImportError: + print("Error: xformers module is not installed. Please install it to use flash attention.") + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=attention_mask, + p=self.dropout, + scale=self.scaling) + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index b87db53f45f1..725def7f66ab 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,4 +1,10 @@ -from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedLayerNorm, + Linear1D_Col, + Linear1D_Row, + VocabParallelEmbedding1D, + opt_flash_attention_forward, +) from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -29,67 +35,90 @@ def preprocess(self): def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer - policy = {} - - if self.shard_config.enable_tensor_parallelism: - policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]) - policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]) - - policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]) + base_policy = { + OPTDecoder: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]), + OPTDecoderLayer: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]), + OPTAttention: + ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]), + } # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), - policy=policy, - target_key=OPTDecoder) - self.append_or_create_submodule_replacement(description=[ + 
base_policy[OPTDecoder].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="final_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True)) + base_policy[OPTDecoderLayer].sub_module_replacement.extend([ SubModuleReplacementDescription(suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True) - ], - policy=policy, - target_key=OPTDecoderLayer) + ]) - return policy + # use flash attention + if self.shard_config.enable_flash_attention: + del base_policy[OPTAttention] + new_item = { + OPTAttention: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + method_replacement={ + 'forward': opt_flash_attention_forward, + }, + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="q_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="k_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="v_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + ]), + } + base_policy.update(new_item) + + return base_policy def postprocess(self): return self.model @@ -107,12 +136,15 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() + new_item = { + OPTForCausalLM: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } - if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=OPTForCausalLM) + policy.update(new_item) return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 83c08d275df3..0e826d38ddd0 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,14 +13,14 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. - enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. + enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. enable_all_optimization (bool): Whether to turn on all optimization, default is False. 
""" tensor_parallel_process_group: ProcessGroup = None - enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False + enable_flash_attention: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int @@ -34,11 +34,8 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): - if not self.enable_tensor_parallelism: - self._tensor_parallel_size = 1 - else: - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) # turn on all optimization if all_optimization is set to True if self.enable_all_optimization: @@ -50,3 +47,4 @@ def _turn_on_all_optimization(self): """ # you can add all the optimization flag here self.enable_fused_normalization = True + self.enable_flash_attention = True diff --git a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py new file mode 100644 index 000000000000..7f196e86a253 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py @@ -0,0 +1,38 @@ +from copy import deepcopy + +import torch +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.opt.modeling_opt import OPTAttention + +from colossalai.shardformer.layer import opt_flash_attention_forward + + +def test_flash_attention_for_opt(): + BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 + + # generate input + hidden_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), + dtype=torch.float32, + device="cuda", + requires_grad=True) + key_value_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), + dtype=torch.float32, + device="cuda", + requires_grad=True) + attention_mask = torch.ones((BATCH, 1, N_CTX, N_CTX), dtype=torch.float32, device="cuda") + + opt_attention = OPTAttention(embed_dim=D_HEAD * N_HEADS, num_heads=N_HEADS, dropout=0, is_decoder=True, + bias=True).to("cuda") + + opt_flash_attention = deepcopy(opt_attention) + setattr(opt_flash_attention, 'forward', + opt_flash_attention_forward.__get__(opt_flash_attention, opt_flash_attention.__class__)) + opt_attention_output = opt_attention(hidden_states, key_value_states, attention_mask=attention_mask) + flash_attention_output = opt_flash_attention(hidden_states, key_value_states, attention_mask=attention_mask) + + assert_close(flash_attention_output[0], opt_attention_output[0], atol=1e-5, rtol=1e-5) + + +if __name__ == '__main__': + test_flash_attention_for_opt() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d83d9ecd39e0..7eff506e0c28 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,13 +3,12 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True): +def build_model(model_fn, enable_flash_attention=False): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism) + shard_config = ShardConfig(enable_fused_normalization=True, enable_flash_attention=enable_flash_attention) model_copy = copy.deepcopy(org_model) 
shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 25bccb13b1a8..af25ba6a8c6c 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,9 +6,9 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, + check_state_dict_equal, clear_cache_before_run, parameterize, rerun_if_address_is_in_use, @@ -43,53 +43,44 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad - shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # check embedding grad org_grad = opt_model.decoder.embed_tokens.weight.grad shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad - shard_weight = shard_opt_model.decoder.embed_tokens.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +def check_OPTModel(enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() -def check_OPTModel(rank, world_size, port): +def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', 
port=port, backend='nccl') - run_t5_test() + check_OPTModel() + torch.cuda.empty_cache() @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_OPTModel(): - spawn(check_OPTModel, 4) + spawn(run_dist, 4) if __name__ == '__main__': From ac179b5c7ea8d11b77711c1eae1b916a72f656fc Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 17:13:19 +0800 Subject: [PATCH 02/16] [shardformer] opt support flash attention --- colossalai/shardformer/policies/opt.py | 161 +++++++++--------- colossalai/shardformer/shard/shard_config.py | 14 +- tests/test_shardformer/test_model/_utils.py | 9 +- .../test_model/test_shard_opt.py | 38 +++-- 4 files changed, 120 insertions(+), 102 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 725def7f66ab..5e1c1cf36ed4 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -35,90 +35,94 @@ def preprocess(self): def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer - base_policy = { - OPTDecoder: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]), - OPTDecoderLayer: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]), - OPTAttention: - ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]), - } + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]) + policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]) + + policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[OPTDecoder].sub_module_replacement.append( - 
SubModuleReplacementDescription(suffix="final_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True)) - base_policy[OPTDecoderLayer].sub_module_replacement.extend([ + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription(suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True) - ]) + ], + policy=policy, + target_key=OPTDecoderLayer) # use flash attention if self.shard_config.enable_flash_attention: - del base_policy[OPTAttention] - new_item = { - OPTAttention: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - method_replacement={ - 'forward': opt_flash_attention_forward, - }, - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="q_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="k_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="v_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - ]), - } - base_policy.update(new_item) - - return base_policy + policy[OPTAttention] = ModulePolicyDescription( + method_replacement={ + 'forward': opt_flash_attention_forward, + }, + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="q_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="k_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="v_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="out_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + ]) + + return policy def postprocess(self): return self.model @@ -136,15 +140,12 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() - new_item = { - OPTForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=OPTForCausalLM) - policy.update(new_item) return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0e826d38ddd0..792e9a3f7ad4 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,11 +13,12 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global 
process group. + enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. - enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. enable_all_optimization (bool): Whether to turn on all optimization, default is False. """ tensor_parallel_process_group: ProcessGroup = None + enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False enable_flash_attention: bool = False @@ -34,12 +35,11 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) - - # turn on all optimization if all_optimization is set to True - if self.enable_all_optimization: - self._turn_on_all_optimization() + if not self.enable_tensor_parallelism: + self._tensor_parallel_size = 1 + else: + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) def _turn_on_all_optimization(self): """ diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 7eff506e0c28..44a8f841a6d4 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,12 +3,17 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_flash_attention=False): +def build_model(model_fn, + enable_fused_normalization=True, + enable_tensor_parallelism=True, + enable_flash_attention=False): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(enable_fused_normalization=True, enable_flash_attention=enable_flash_attention) + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index af25ba6a8c6c..9ca0669340a5 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,6 +6,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, check_state_dict_equal, @@ -43,44 +44,55 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad + shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = 
torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # check embedding grad org_grad = opt_model.decoder.embed_tokens.weight.grad shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + shard_weight = shard_opt_model.decoder.embed_tokens.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) @parameterize('enable_flash_attention', [True, False]) -def check_OPTModel(enable_flash_attention): +def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_flash_attention) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() -def run_dist(rank, world_size, port): +def check_OPTModel(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_OPTModel() - torch.cuda.empty_cache() + run_opt_test() @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_OPTModel(): - spawn(run_dist, 4) + spawn(check_OPTModel, 4) if __name__ == '__main__': From 00c4a826e2fe0fcc6e652b87fdb80e60e83fdbfe Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 17:21:08 +0800 Subject: [PATCH 03/16] [shardformer] opt support flash attention --- colossalai/shardformer/policies/opt.py | 6 ------ colossalai/shardformer/shard/shard_config.py | 3 +++ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 5e1c1cf36ed4..7a838a632776 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -79,12 +79,6 @@ def module_policy(self): ]) # optimization configuration - if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), - policy=policy, - target_key=OPTDecoder) - # optimization configuration if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), diff --git 
a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 792e9a3f7ad4..4076becdbbc9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -40,6 +40,9 @@ def __post_init__(self): else: # get the parallel size self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + # turn on all optimization if all_optimization is set to True + if self.enable_all_optimization: + self._turn_on_all_optimization() def _turn_on_all_optimization(self): """ From 27526d943654b10f75304d3f08e1f02dc8134dfc Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 18:22:23 +0800 Subject: [PATCH 04/16] [shardformer] opt support flash attention --- colossalai/shardformer/layer/__init__.py | 2 +- .../layer/flash_attention/__init__.py | 3 --- .../flash_attention/flash_attention_opt.py | 5 +++-- colossalai/shardformer/policies/opt.py | 21 +++---------------- 4 files changed, 7 insertions(+), 24 deletions(-) delete mode 100644 colossalai/shardformer/layer/flash_attention/__init__.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index fef65fc5fa52..a994e8f37e91 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,6 +1,6 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D -from .flash_attention import opt_flash_attention_forward +from .flash_attention.flash_attention_opt import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm diff --git a/colossalai/shardformer/layer/flash_attention/__init__.py b/colossalai/shardformer/layer/flash_attention/__init__.py deleted file mode 100644 index e3ee83b005c4..000000000000 --- a/colossalai/shardformer/layer/flash_attention/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .flash_attention_opt import opt_flash_attention_forward - -__all__ = ['opt_flash_attention_forward'] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py index 3f0445078357..8b084fcb72ae 100644 --- a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py +++ b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py @@ -68,8 +68,9 @@ def opt_flash_attention_forward( try: from xformers.ops import memory_efficient_attention as me_attention - except ImportError: - print("Error: xformers module is not installed. Please install it to use flash attention.") + except: + raise ImportError("Error: xformers module is not installed. 
Please install it to use flash attention.") + attn_output = me_attention(query_states, key_states, value_states, diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 7a838a632776..9ca0373c69da 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -97,24 +97,9 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[OPTAttention] = ModulePolicyDescription( - method_replacement={ - 'forward': opt_flash_attention_forward, - }, - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="q_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="k_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="v_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="out_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - ]) + policy[OPTAttention] = ModulePolicyDescription(method_replacement={ + 'forward': opt_flash_attention_forward, + }) return policy From 70535bd398f25660f60a6912856151172fa128d5 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 5 Jul 2023 11:16:46 +0800 Subject: [PATCH 05/16] [shardformer] move to modeling --- colossalai/shardformer/layer/__init__.py | 15 +++- .../flash_attention/flash_attention_opt.py | 88 ------------------- colossalai/shardformer/policies/opt.py | 9 +- .../test_flash_attention_for_opt.py | 38 -------- 4 files changed, 13 insertions(+), 137 deletions(-) delete mode 100644 colossalai/shardformer/layer/flash_attention/flash_attention_opt.py delete mode 100644 tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index a994e8f37e91..72e002497fa5 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,13 +1,20 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D -from .flash_attention.flash_attention_opt import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ - "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', - 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm', 'opt_flash_attention_forward' + "Embedding1D", + "VocabParallelEmbedding1D", + "Linear1D_Col", + "Linear1D_Row", + 'GPT2FusedLinearConv1D_Col', + 'GPT2FusedLinearConv1D_Row', + 'DropoutForParallelInput', + 'DropoutForReplicatedInput', + "cross_entropy_1d", + 'FusedLayerNorm', + 'FusedRMSNorm', ] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py deleted file mode 100644 index 8b084fcb72ae..000000000000 --- a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Optional, Tuple - -import torch - -__all__ = ['opt_flash_attention_forward'] - - -def opt_flash_attention_forward( 
- self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - - attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) - # get query proj - # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) - query_states = self.q_proj(hidden_states).view(*attention_input_shape) - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) - value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*attention_input_shape) - value_states = self.v_proj(key_value_states).view(*attention_input_shape) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) - key_states = torch.cat([past_key_value[0], key_states], dim=1) - value_states = torch.cat([past_key_value[1], value_states], dim=1) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - src_len = key_states.size(1) - if layer_head_mask != None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}") - if attention_mask != None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") - attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() - - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. 
Please install it to use flash attention.") - - attn_output = me_attention(query_states, - key_states, - value_states, - attn_bias=attention_mask, - p=self.dropout, - scale=self.scaling) - - attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 9ca0373c69da..fa7bf1befd44 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,12 +1,7 @@ -from colossalai.shardformer.layer import ( - FusedLayerNorm, - Linear1D_Col, - Linear1D_Row, - VocabParallelEmbedding1D, - opt_flash_attention_forward, -) +from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .._utils import getattr_, setattr_ +from ..modeling.opt import opt_flash_attention_forward from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ diff --git a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py deleted file mode 100644 index 7f196e86a253..000000000000 --- a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py +++ /dev/null @@ -1,38 +0,0 @@ -from copy import deepcopy - -import torch -import torch.nn as nn -from torch.testing import assert_close -from transformers.models.opt.modeling_opt import OPTAttention - -from colossalai.shardformer.layer import opt_flash_attention_forward - - -def test_flash_attention_for_opt(): - BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 - - # generate input - hidden_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), - dtype=torch.float32, - device="cuda", - requires_grad=True) - key_value_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), - dtype=torch.float32, - device="cuda", - requires_grad=True) - attention_mask = torch.ones((BATCH, 1, N_CTX, N_CTX), dtype=torch.float32, device="cuda") - - opt_attention = OPTAttention(embed_dim=D_HEAD * N_HEADS, num_heads=N_HEADS, dropout=0, is_decoder=True, - bias=True).to("cuda") - - opt_flash_attention = deepcopy(opt_attention) - setattr(opt_flash_attention, 'forward', - opt_flash_attention_forward.__get__(opt_flash_attention, opt_flash_attention.__class__)) - opt_attention_output = opt_attention(hidden_states, key_value_states, attention_mask=attention_mask) - flash_attention_output = opt_flash_attention(hidden_states, key_value_states, attention_mask=attention_mask) - - assert_close(flash_attention_output[0], opt_attention_output[0], atol=1e-5, rtol=1e-5) - - -if __name__ == '__main__': - test_flash_attention_for_opt() From 794dd86a6e63e97df35cad6c746c08caace758e6 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 5 Jul 2023 11:19:51 +0800 Subject: [PATCH 06/16] [shardformer] move to modeling --- colossalai/shardformer/modeling/opt.py | 88 ++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 colossalai/shardformer/modeling/opt.py diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py new file mode 100644 index 
000000000000..8b084fcb72ae --- /dev/null +++ b/colossalai/shardformer/modeling/opt.py @@ -0,0 +1,88 @@ +from typing import Optional, Tuple + +import torch + +__all__ = ['opt_flash_attention_forward'] + + +def opt_flash_attention_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. 
Please install it to use flash attention.") + + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=attention_mask, + p=self.dropout, + scale=self.scaling) + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value From cd14ceaa38a353920f0b28b76d88619a3345c591 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Fri, 14 Jul 2023 16:16:44 +0800 Subject: [PATCH 07/16] [shardformer] bloom support jit fused operator --- colossalai/shardformer/modeling/bloom.py | 181 ++++++++++++++++-- colossalai/shardformer/policies/bloom.py | 27 ++- colossalai/shardformer/shard/shard_config.py | 2 + tests/test_shardformer/test_model/_utils.py | 6 +- .../test_model/test_shard_bloom.py | 6 +- 5 files changed, 199 insertions(+), 23 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 4b6458ae0627..16b451227833 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -71,15 +71,15 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, return build_bloom_alibi_tensor -def get_bloom_forward(): + +def get_flash_attention_forward(enabel_jit_fused=False): try: from xformers.ops import memory_efficient_attention as me_attention except: raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - from transformers.models.bloom.modeling_bloom import dropout_add - def bloom_flash_attention_forward( + def forward( self, hidden_states: torch.Tensor, residual: torch.Tensor, @@ -88,7 +88,8 @@ def bloom_flash_attention_forward( layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, - output_attentions: bool = False,): + output_attentions: bool = False, + ): fused_qkv = self.query_key_value(hidden_states) (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) @@ -114,30 +115,180 @@ def bloom_flash_attention_forward( present = (key_layer, value_layer) else: present = None - + tgt_len = key_layer.size()[1] - attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length), dtype=torch.float32, device=query_layer.device, requires_grad=True) - attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta - attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min) - - context_layer = me_attention(query_layer, key_layer, value_layer, attn_bias=attention_numerical_mask, scale=self.inv_norm_factor, p=self.attention_dropout.p) + attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length), + dtype=torch.float32, + device=query_layer.device, + requires_grad=True) + attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, + kv_length) * self.beta + attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask, + torch.finfo(torch.float32).min) + + context_layer = me_attention(query_layer, + key_layer, + value_layer, + 
attn_bias=attention_numerical_mask, + scale=self.inv_norm_factor, + p=self.attention_dropout.p) context_layer = context_layer.reshape(-1, kv_length, self.hidden_size) if self.pretraining_tp > 1 and self.slow_but_exact: slices = self.hidden_size / self.pretraining_tp output_tensor = torch.zeros_like(context_layer) for i in range(self.pretraining_tp): output_tensor = output_tensor + F.linear( - context_layer[:, :, int(i * slices) : int((i + 1) * slices)], - self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + context_layer[:, :, int(i * slices):int((i + 1) * slices)], + self.dense.weight[:, int(i * slices):int((i + 1) * slices)], ) else: output_tensor = self.dense(context_layer) # TODO to replace with the bias_dropout_add function in jit - output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training) outputs = (output_tensor, present, None) return outputs - - return bloom_flash_attention_forward \ No newline at end of file + + return forward + + +def get_jit_fused_attention_forward(): + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, q_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, head_dim, kv_length] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=2) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, _, kv_length = key_layer.shape + + if use_cache is True: + present = (key_layer, value_layer) + else: + present = None + + # [batch_size * num_heads, q_length, kv_length] + # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 + matmul_result = alibi.baddbmm( + batch1=query_layer, + batch2=key_layer, + beta=self.beta, + alpha=self.inv_norm_factor, + ) + + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16: + attention_scores = attention_scores.to(torch.float) + attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min) + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) + + # 
[batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # change view [batch_size x num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = torch.bmm(attention_probs_reshaped, value_layer) + + # change view [batch_size, num_heads, q_length, head_dim] + context_layer = self._merge_heads(context_layer) + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices):int((i + 1) * slices)], + self.dense.weight[:, int(i * slices):int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + + outputs = (output_tensor, present) + if output_attentions: + outputs += (attention_probs,) + + return outputs + + return forward + + +def get_jit_fused_mlp_forward(): + + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states)) + + if self.pretraining_tp > 1 and self.slow_but_exact: + intermediate_output = torch.zeros_like(residual) + slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp + for i in range(self.pretraining_tp): + intermediate_output = intermediate_output + F.linear( + hidden_states[:, :, int(i * slices):int((i + 1) * slices)], + self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)], + ) + else: + intermediate_output = self.dense_4h_to_h(hidden_states) + output = self.dropout_add(intermediate_output, residual, self.hidden_dropout, self.training) + return output + + return forward + + +def get_dropout_add_func(): + + from transformers.models.bloom.modeling_bloom import dropout_add + + def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + return dropout_add(x, residual, prob, training) + + return self_dropout_add + + +def get_jit_fused_dropout_add_func(): + + from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train + + def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + bias = torch.zeros_like(x) + if training: + return bias_dropout_add_fused_train(x, bias, residual, prob) + return bias_dropout_add_fused_inference(x, bias, residual, prob) + + return self_dropout_add diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index daae3cf881a2..91d8f9053ea9 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -3,7 +3,14 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ -from ..modeling.bloom import build_bloom_alibi_tensor_fn, get_bloom_forward +from ..modeling.bloom import ( + build_bloom_alibi_tensor_fn, + get_dropout_add_func, + get_flash_attention_forward, + get_jit_fused_attention_forward, + get_jit_fused_dropout_add_func, + get_jit_fused_mlp_forward, +) from .basepolicy import 
ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -25,7 +32,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, BloomAttention + from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomMLP, BloomModel policy = {} @@ -101,10 +108,22 @@ def module_policy(self): ], policy=policy, target_key=BloomBlock) - + if self.shard_config.enable_flash_attention: policy[BloomAttention] = ModulePolicyDescription(method_replacement={ - 'forward': get_bloom_forward(), + 'forward': get_flash_attention_forward(), + 'dropout_add': get_dropout_add_func() + }) + + # enable jit fused operator + if self.shard_config.enable_jit_fused: + policy[BloomAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_attention_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + policy[BloomMLP] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_mlp_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), }) return policy diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 4076becdbbc9..f9d93609feef 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -22,6 +22,7 @@ class ShardConfig: enable_fused_normalization: bool = False enable_all_optimization: bool = False enable_flash_attention: bool = False + enable_jit_fused: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int @@ -51,3 +52,4 @@ def _turn_on_all_optimization(self): # you can add all the optimization flag here self.enable_fused_normalization = True self.enable_flash_attention = True + self.enable_jit_fused = True diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 44a8f841a6d4..4dcaac143525 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -6,14 +6,16 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, - enable_flash_attention=False): + enable_flash_attention=False, + enable_jit_fused=False): # create new model org_model = model_fn().cuda() # shard model shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, enable_tensor_parallelism=enable_tensor_parallelism, - enable_flash_attention=enable_flash_attention) + enable_flash_attention=enable_flash_attention, + enable_jit_fused=enable_jit_fused) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index 63438bbe9a91..c17d55a9126c 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -70,10 +70,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) @parameterize('enable_flash_attention', [True, False]) -def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): +@parameterize('enable_jit_fused', [True, False]) +def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, 
enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + enable_flash_attention, enable_jit_fused) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From 1b4beb014afe07cb2dc9eb5b254c03d3e6771549 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Mon, 17 Jul 2023 16:07:48 +0800 Subject: [PATCH 08/16] [shardformer] bloom support jit fused operator --- colossalai/shardformer/modeling/bert.py | 47 ++++++-- colossalai/shardformer/modeling/bloom.py | 25 ++-- colossalai/shardformer/modeling/jit.py | 34 ++++++ colossalai/shardformer/modeling/opt.py | 111 ++++++++++++++++-- colossalai/shardformer/policies/bert.py | 30 ++++- colossalai/shardformer/policies/bloom.py | 10 +- colossalai/shardformer/policies/opt.py | 12 +- .../test_model/test_shard_bert.py | 6 +- .../test_model/test_shard_opt.py | 5 +- 9 files changed, 226 insertions(+), 54 deletions(-) create mode 100644 colossalai/shardformer/modeling/jit.py diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index faa579529ce4..6e2f3debd118 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1,5 +1,5 @@ -from typing import Optional, Tuple import math +from typing import Optional, Tuple import torch import torch.distributed as dist @@ -8,14 +8,15 @@ __all__ = ['get_bert_forward'] -def get_bert_forward(): + +def get_bert_flash_attention_forward(): try: from xformers.ops import memory_efficient_attention as me_attention except: raise ImportError("Error: xformers module is not installed. 
Please install it to use flash attention.") - - def bert_flash_attention_forward( + + def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, @@ -67,16 +68,14 @@ def bert_flash_attention_forward( if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( - -1, 1 - ) + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1) else: position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) @@ -99,8 +98,13 @@ def bert_flash_attention_forward( query_layer = query_layer.permute(0, 2, 1, 3).contiguous() key_layer = key_layer.permute(0, 2, 1, 3).contiguous() value_layer = value_layer.permute(0, 2, 1, 3).contiguous() - - context_layer = me_attention(query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale) + + context_layer = me_attention(query_layer, + key_layer, + value_layer, + attn_bias=final_attention_mask, + p=self.dropout.p, + scale=scale) new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) @@ -109,5 +113,24 @@ def bert_flash_attention_forward( if self.is_decoder: outputs = outputs + (past_key_value,) return outputs - - return bert_flash_attention_forward \ No newline at end of file + + return forward + + +def _get_jit_fused_output_forward(): + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + return forward + + +def get_jit_fused_bert_self_output_forward(): + return _get_jit_fused_output_forward() + + +def get_jit_fused_bert_output_forward(): + return _get_jit_fused_output_forward() diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 16b451227833..c43fb90af947 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -271,24 +271,15 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch. 
return forward -def get_dropout_add_func(): +def get_jit_fused_bloom_gelu_forward(): - from transformers.models.bloom.modeling_bloom import dropout_add + from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction - def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: - return dropout_add(x, residual, prob, training) - - return self_dropout_add - - -def get_jit_fused_dropout_add_func(): - - from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train - - def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: bias = torch.zeros_like(x) - if training: - return bias_dropout_add_fused_train(x, bias, residual, prob) - return bias_dropout_add_fused_inference(x, bias, residual, prob) + if self.training: + return JitGeLUFunction.apply(x, bias) + else: + return self.bloom_gelu_forward(x, bias) - return self_dropout_add + return forward diff --git a/colossalai/shardformer/modeling/jit.py b/colossalai/shardformer/modeling/jit.py new file mode 100644 index 000000000000..6434348ef823 --- /dev/null +++ b/colossalai/shardformer/modeling/jit.py @@ -0,0 +1,34 @@ +import torch + + +def get_dropout_add_func(): + + from transformers.models.bloom.modeling_bloom import dropout_add + + def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + return dropout_add(x, residual, prob, training) + + return self_dropout_add + + +def get_jit_fused_dropout_add_func(): + + from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train + + def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + bias = torch.zeros_like(x) + if training: + return bias_dropout_add_fused_train(x, bias, residual, prob) + return bias_dropout_add_fused_inference(x, bias, residual, prob) + + return self_dropout_add + + +def get_jit_fused_gelu_forward_func(): + + from colossalai.kernel.jit.bias_gelu import bias_gelu + + def bloom_gelu_forward(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: + return bias_gelu(bias, x) + + return bloom_gelu_forward diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index c02c17d023d7..cb0a0a7db8c0 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -1,18 +1,19 @@ from typing import Optional, Tuple import torch +from torch import nn __all__ = ['get_opt_forward'] -def get_opt_forward(): - +def get_opt_flash_attention_forward(): + try: from xformers.ops import memory_efficient_attention as me_attention except: raise ImportError("Error: xformers module is not installed. 
Please install it to use flash attention.") - - def opt_flash_attention_forward( + + def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, @@ -67,7 +68,7 @@ def opt_flash_attention_forward( if layer_head_mask != None: if layer_head_mask.size() != (self.num_heads,): raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}") + f" {layer_head_mask.size()}") if attention_mask != None: if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError( @@ -75,11 +76,11 @@ def opt_flash_attention_forward( attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() attn_output = me_attention(query_states, - key_states, - value_states, - attn_bias=attention_mask, - p=self.dropout, - scale=self.scaling) + key_states, + value_states, + attn_bias=attention_mask, + p=self.dropout, + scale=self.scaling) attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) @@ -89,5 +90,91 @@ def opt_flash_attention_forward( attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value - - return opt_flash_attention_forward + + return forward + + +def get_jit_fused_opt_decoder_layer_forward(): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # hidden_states = residual + hidden_states + + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # hidden_states = (residual + hidden_states).view(hidden_states_shape) + + hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + return forward diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index a213538f8809..e9f6a3f094cb 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -3,8 +3,13 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ +from ..modeling.bert import ( + get_bert_flash_attention_forward, + get_jit_fused_bert_output_forward, + get_jit_fused_bert_self_output_forward, +) +from ..modeling.jit import get_jit_fused_dropout_add_func from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -from ..modeling.bert import get_bert_forward __all__ = [ 'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy', @@ -32,7 +37,13 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertSelfAttention + from transformers.models.bert.modeling_bert import ( + BertEmbeddings, + BertLayer, + BertOutput, + BertSelfAttention, + BertSelfOutput, + ) policy = {} @@ -121,11 +132,22 @@ def module_policy(self): )], policy=policy, target_key=BertEmbeddings) - + # use flash attention if self.shard_config.enable_flash_attention: policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={ - 'forward': get_bert_forward(), + 'forward': get_bert_flash_attention_forward(), + }) + + # use jit operator + if self.shard_config.enable_jit_fused: + policy[BertSelfOutput] = 
ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_bert_self_output_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + policy[BertOutput] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_bert_output_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), }) return policy diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 91d8f9053ea9..4dc673213a20 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -5,12 +5,12 @@ from .._utils import getattr_, setattr_ from ..modeling.bloom import ( build_bloom_alibi_tensor_fn, - get_dropout_add_func, get_flash_attention_forward, get_jit_fused_attention_forward, - get_jit_fused_dropout_add_func, + get_jit_fused_bloom_gelu_forward, get_jit_fused_mlp_forward, ) +from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -32,7 +32,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomMLP, BloomModel + from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomGelu, BloomMLP, BloomModel policy = {} @@ -125,6 +125,10 @@ def module_policy(self): 'forward': get_jit_fused_mlp_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), }) + policy[BloomGelu] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_bloom_gelu_forward(), + 'bloom_gelu_forward': get_jit_fused_gelu_forward_func(), + }) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index e960227415cc..f363096158ae 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,7 +1,8 @@ from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .._utils import getattr_, setattr_ -from ..modeling.opt import get_opt_forward +from ..modeling.jit import get_jit_fused_dropout_add_func +from ..modeling.opt import get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -93,7 +94,14 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: policy[OPTAttention] = ModulePolicyDescription(method_replacement={ - 'forward': get_opt_forward(), + 'forward': get_opt_flash_attention_forward(), + }) + + # use jit fused operator + if self.shard_config.enable_jit_fused: + policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_opt_decoder_layer_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), }) return policy diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 865ba2fa7126..a025b0cee0b3 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -70,10 +70,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) @parameterize('enable_flash_attention', [True, False]) -def run_bert_test(enable_fused_normalization, 
enable_tensor_parallelism, enable_flash_attention): +@parameterize('enable_jit_fused', [True, False]) +def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + enable_flash_attention, enable_jit_fused) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 9ca0669340a5..0eb04960460b 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -73,11 +73,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) @parameterize('enable_flash_attention', [True, False]) -def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): +@parameterize('enable_jit_fused', [True, False]) +def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention) + enable_flash_attention, enable_jit_fused) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From 1106da87eb95d4e79d8421c7c12d51e2333ae2f4 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Mon, 17 Jul 2023 16:26:49 +0800 Subject: [PATCH 09/16] [shardformer] bloom support jit fused operator --- colossalai/shardformer/modeling/bert.py | 2 -- colossalai/shardformer/modeling/bloom.py | 6 ++--- colossalai/shardformer/modeling/gpt2.py | 33 ++++++++++++++---------- colossalai/shardformer/modeling/llama.py | 33 +++++------------------- colossalai/shardformer/modeling/opt.py | 2 -- colossalai/shardformer/policies/bloom.py | 12 ++++----- colossalai/shardformer/policies/gpt2.py | 10 +++---- colossalai/shardformer/policies/llama.py | 10 +++---- 8 files changed, 45 insertions(+), 63 deletions(-) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 6e2f3debd118..689c142ce277 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -6,8 +6,6 @@ from torch.distributed import ProcessGroup from torch.nn import functional as F -__all__ = ['get_bert_forward'] - def get_bert_flash_attention_forward(): diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index c43fb90af947..080ace10eb82 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -72,7 +72,7 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, return build_bloom_alibi_tensor -def 
get_flash_attention_forward(enabel_jit_fused=False): +def get_bloom_flash_attention_forward(enabel_jit_fused=False): try: from xformers.ops import memory_efficient_attention as me_attention @@ -154,7 +154,7 @@ def forward( return forward -def get_jit_fused_attention_forward(): +def get_jit_fused_bloom_attention_forward(): def forward( self, @@ -250,7 +250,7 @@ def forward( return forward -def get_jit_fused_mlp_forward(): +def get_jit_fused_bloom_mlp_forward(): def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states)) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index a9a38bce235f..a87a3630e867 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -2,17 +2,16 @@ import torch -__all__ = ['get_gpt2_forward'] -def get_gpt2_forward(): +def get_gpt2_flash_attention_forward(): try: from xformers.ops import memory_efficient_attention as me_attention from xformers.ops.fmha.attn_bias import LowerTriangularMask except: raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - - def gpt2_flash_attention_forward( + + def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], layer_past: Optional[Tuple[torch.Tensor]] = None, @@ -30,8 +29,7 @@ def gpt2_flash_attention_forward( if not hasattr(self, "q_attn"): raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." - ) + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.") query = self.q_attn(hidden_states) key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) @@ -52,7 +50,7 @@ def gpt2_flash_attention_forward( present = (key, value) else: present = None - + attn_bias = None if not self.is_cross_attention: attn_bias = LowerTriangularMask() @@ -62,20 +60,26 @@ def gpt2_flash_attention_forward( else: batch_size, _, tgt_len, src_len = attention_mask.size() attn_bias = attention_mask.expand(batch_size, self.num_heads, tgt_len, src_len).contiguous() - - scale = value.size(-1) ** -0.5 + + scale = value.size(-1)**-0.5 if self.scale_attn_by_inverse_layer_idx: scale = scale * (1 / float(self.layer_idx + 1)) - attn_output = me_attention(query=query, key=key, value=value, attn_bias=attn_bias, p=self.attn_dropout.p, scale=scale) - + attn_output = me_attention(query=query, + key=key, + value=value, + attn_bias=attn_bias, + p=self.attn_dropout.p, + scale=scale) + attn_output = merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present, None) return outputs - - return gpt2_flash_attention_forward + + return forward + def split_heads(tensor, num_heads, attn_head_size): """ @@ -85,9 +89,10 @@ def split_heads(tensor, num_heads, attn_head_size): tensor = tensor.view(new_shape) return tensor + def merge_heads(tensor, num_heads, attn_head_size): """ Merges attn_head_size dim and num_attn_heads dim into hidden_size """ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) \ No newline at end of file + return tensor.view(new_shape) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index aab56f056436..91e60a7c531c 100644 --- 
a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -2,35 +2,16 @@ import torch -__all__ = ['get_llama_forward'] - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def get_llama_forward(): +def get_llama_flash_attention_forward(): try: from xformers.ops import memory_efficient_attention as me_attention except: raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - - def llama_flash_attention_forward( + from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + + def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -74,10 +55,10 @@ def llama_flash_attention_forward( attn_output = me_attention(query_states, key_states, value_states, attn_bias=attention_mask) if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): raise ValueError(f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" - f" {attn_output.size()}") + f" {attn_output.size()}") attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value - - return llama_flash_attention_forward + + return forward diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index cb0a0a7db8c0..6ef6547b3e16 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -3,8 +3,6 @@ import torch from torch import nn -__all__ = ['get_opt_forward'] - def get_opt_flash_attention_forward(): diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 4dc673213a20..b2c86ec6c632 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -5,10 +5,10 @@ from .._utils import getattr_, setattr_ from ..modeling.bloom import ( build_bloom_alibi_tensor_fn, - get_flash_attention_forward, - get_jit_fused_attention_forward, + get_bloom_flash_attention_forward, + get_jit_fused_bloom_attention_forward, get_jit_fused_bloom_gelu_forward, - get_jit_fused_mlp_forward, + get_jit_fused_bloom_mlp_forward, ) from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -111,18 +111,18 @@ def module_policy(self): if self.shard_config.enable_flash_attention: policy[BloomAttention] = ModulePolicyDescription(method_replacement={ - 'forward': get_flash_attention_forward(), + 'forward': get_bloom_flash_attention_forward(), 'dropout_add': get_dropout_add_func() }) # enable jit fused operator if self.shard_config.enable_jit_fused: policy[BloomAttention] = ModulePolicyDescription(method_replacement={ - 'forward': get_jit_fused_attention_forward(), + 'forward': 
get_jit_fused_bloom_attention_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), }) policy[BloomMLP] = ModulePolicyDescription(method_replacement={ - 'forward': get_jit_fused_mlp_forward(), + 'forward': get_jit_fused_bloom_mlp_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), }) policy[BloomGelu] = ModulePolicyDescription(method_replacement={ diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index cd11bdb39f94..f4cc18024f6d 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -3,8 +3,8 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ +from ..modeling.gpt2 import get_gpt2_flash_attention_forward from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -from ..modeling.gpt2 import get_gpt2_forward __all__ = [ 'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy', @@ -30,7 +30,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model, GPT2Attention + from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model policy = {} @@ -107,11 +107,11 @@ def module_policy(self): ], policy=policy, target_key=GPT2Block) - + if self.shard_config.enable_flash_attention: policy[GPT2Attention] = ModulePolicyDescription(method_replacement={ - 'forward': get_gpt2_forward(), - }) + 'forward': get_gpt2_flash_attention_forward(), + }) return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 7664492f7686..c4ed5bb0e342 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -4,8 +4,8 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from ..modeling.llama import get_llama_flash_attention_forward from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -from ..modeling.llama import get_llama_forward __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] @@ -27,7 +27,7 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel, LlamaAttention + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel policy = {} @@ -99,11 +99,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ), policy=policy, target_key=LlamaModel) - + if self.shard_config.enable_flash_attention: policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ - 'forward': get_llama_forward(), - }) + 'forward': get_llama_flash_attention_forward(), + }) return policy From c3c2af3096ef23c26976f2df02ae528e858ca827 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Mon, 17 Jul 2023 17:29:56 +0800 Subject: [PATCH 10/16] [shardformer] t5 support jit fused operator --- colossalai/shardformer/modeling/t5.py | 84 ++++++++++++++++++- colossalai/shardformer/policies/t5.py | 25 +++++- .../test_model/test_shard_t5.py | 6 +- 3 files changed, 107 insertions(+), 8 deletions(-) diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index a23b24e6f571..9e0519f4f1b1 100644 --- 
a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -2,17 +2,17 @@ import torch -__all__ = ['get_t5_forward'] +__all__ = ['get_t5_flash_attention_forward'] -def get_t5_forward(): +def get_t5_flash_attention_forward(): try: from xformers.ops import memory_efficient_attention as me_attention except: raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - def t5_flash_attention_forward( + def forward( self, hidden_states, mask=None, @@ -128,4 +128,80 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return outputs - return t5_flash_attention_forward + return forward + + +def get_jit_fused_T5_layer_ff_forward(): + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = self.dropout_add(forwarded_states, hidden_states, self.dropout.p, self.dropout.training) + # hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + return forward + + +def get_T5_layer_self_attention_forward(): + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + # hidden_states = hidden_states + self.dropout(attention_output[0]) + hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + return forward + + +def get_T5_layer_cross_attention_forward(): + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + # layer_output = hidden_states + self.dropout(attention_output[0]) + layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + return forward diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 5463878b04b3..846451ab57ee 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -9,8 +9,14 @@ from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription from .._utils import getattr_, setattr_ +from ..modeling.jit import get_jit_fused_dropout_add_func +from ..modeling.t5 import ( + get_jit_fused_T5_layer_ff_forward, + get_t5_flash_attention_forward, + get_T5_layer_cross_attention_forward, + get_T5_layer_self_attention_forward, +) from .basepolicy import ModulePolicyDescription, Policy, 
SubModuleReplacementDescription -from ..modeling.t5 import get_t5_forward __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] @@ -166,7 +172,22 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: policy[T5Attention] = ModulePolicyDescription(method_replacement={ - 'forward': get_t5_forward(), + 'forward': get_t5_flash_attention_forward(), + }) + + # use jit operator + if self.shard_config.enable_jit_fused: + policy[T5LayerFF] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_T5_layer_ff_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_T5_layer_self_attention_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_T5_layer_cross_attention_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), }) return policy diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index ff4b71fe76a4..cf2efeb11f43 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -83,10 +83,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) @parameterize('enable_flash_attention', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): +@parameterize('enable_jit_fused', [True, False]) +def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + enable_flash_attention, enable_jit_fused) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From 58560847bdb583016dc192f165441992081f4718 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Mon, 17 Jul 2023 17:39:48 +0800 Subject: [PATCH 11/16] [shardformer] t5 support jit fused operator --- colossalai/shardformer/modeling/t5.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 9e0519f4f1b1..d54037ac01ff 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -137,7 +137,6 @@ def forward(self, hidden_states): forwarded_states = self.layer_norm(hidden_states) forwarded_states = self.DenseReluDense(forwarded_states) hidden_states = self.dropout_add(forwarded_states, hidden_states, self.dropout.p, self.dropout.training) - # hidden_states = hidden_states + self.dropout(forwarded_states) return hidden_states return forward @@ -165,7 +164,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - # hidden_states = hidden_states + self.dropout(attention_output[0]) hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, 
self.dropout.training) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs @@ -199,7 +197,6 @@ def forward( query_length=query_length, output_attentions=output_attentions, ) - # layer_output = hidden_states + self.dropout(attention_output[0]) layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them return outputs From 13ea27e4774c1f032c1afc759023a20a1a29f6d4 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 18 Jul 2023 14:08:50 +0800 Subject: [PATCH 12/16] [shardformer] t5 support jit fused operator --- colossalai/shardformer/modeling/opt.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index b57837dacac5..d0b4552ef49a 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -131,8 +131,6 @@ def forward( layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) - # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # hidden_states = residual + hidden_states hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) @@ -153,9 +151,6 @@ def forward( hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) - # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - - # hidden_states = (residual + hidden_states).view(hidden_states_shape) hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training).view(hidden_states_shape) From 21d8c8d5e743cb657b7b97a5a480687d3daddd91 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 18 Jul 2023 16:39:40 +0800 Subject: [PATCH 13/16] [shardformer] add roadmap of flash attention --- colossalai/shardformer/README.md | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index c9650ce4f712..d0c368367071 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -105,7 +105,20 @@ We will follow this roadmap to develop Shardformer: - [ ] Whisper - [ ] Multi-modal - [ ] To be added - +- [ ] flash attention support + - [ ] NLP + - [x] BERT + - [x] T5 + - [x] LlaMa + - [x] GPT2 + - [x] OPT + - [x] BLOOM + - [ ] GLM + - [ ] RoBERTa + - [ ] ALBERT + - [ ] ERNIE + - [ ] GPT Neo + - [ ] GPT-J ## 💡 API Design We will discuss the major components of `ShardFormer` below to help you better understand how things work. @@ -396,7 +409,7 @@ In the case of using 2 GPUs, the training times are as follows.
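The `org_model` / `shard_model` pairs timed in these tables are built the same way the shardformer test utilities build them. As a minimal sketch (not part of the patch series) of switching on the flags added in this series, assuming the distributed environment has already been launched, xformers is installed, and that `ShardConfig` and `ShardFormer` are importable from `colossalai.shardformer` the way the tests use them:

```python
import copy

from colossalai.shardformer import ShardConfig, ShardFormer  # import path assumed


def build_sharded_copy(model_fn):
    # `model_fn` is any callable returning a supported Hugging Face model (hypothetical helper).
    org_model = model_fn().cuda()

    # Switch on the optimizations introduced in this series; setting
    # enable_all_optimization=True would flip the same flags in one go.
    shard_config = ShardConfig(enable_fused_normalization=True,
                               enable_tensor_parallelism=True,
                               enable_flash_attention=True,
                               enable_jit_fused=True)

    sharded_model = ShardFormer(shard_config=shard_config).optimize(copy.deepcopy(org_model)).cuda()
    return org_model, sharded_model
```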

-In the case of using 4 GPUs, the training times are as follows. +In the case of using 4 GPUs, the training times are as follows. | N_CTX | org_model | shard_model | | :------: | :-----: | :-----: | @@ -419,7 +432,7 @@ As shown in the figures above, when the sequence length is around 1000 or greate ### Convergence -To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results. +To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results. | accuracy | f1 | loss | GPU number | model shard | | :------: | :-----: | :-----: | :--------: | :---------: | From c6c815d3408674b481806bf8ab5b488f939cbc20 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 18 Jul 2023 16:40:22 +0800 Subject: [PATCH 14/16] [shardformer] add roadmap of flash attention --- colossalai/shardformer/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index d0c368367071..0c1772b41c4d 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -105,7 +105,7 @@ We will follow this roadmap to develop Shardformer: - [ ] Whisper - [ ] Multi-modal - [ ] To be added -- [ ] flash attention support +- [ ] Flash attention support - [ ] NLP - [x] BERT - [x] T5 From 4e92f74c12d9e453c1bd933fc9a3d28cfcc0bd23 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 18 Jul 2023 16:42:18 +0800 Subject: [PATCH 15/16] [shardformer] add roadmap of flash attention --- colossalai/shardformer/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 0c1772b41c4d..f05b0d1180fb 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -105,7 +105,7 @@ We will follow this roadmap to develop Shardformer: - [ ] Whisper - [ ] Multi-modal - [ ] To be added -- [ ] Flash attention support +- [ ] Flash Attention Support - [ ] NLP - [x] BERT - [x] T5 From ab53df9e58699030ab39b50c3fde188b17d04523 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 19 Jul 2023 14:51:37 +0800 Subject: [PATCH 16/16] [shardformer] add type hint to 'self' param of forward --- colossalai/shardformer/modeling/bert.py | 22 ++++--- colossalai/shardformer/modeling/bloom.py | 15 +++-- colossalai/shardformer/modeling/gpt2.py | 4 +- colossalai/shardformer/modeling/llama.py | 4 +- colossalai/shardformer/modeling/opt.py | 8 ++- colossalai/shardformer/modeling/t5.py | 73 +++++++++++++----------- 6 files changed, 77 insertions(+), 49 deletions(-) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 689c142ce277..b69ac63d6a8c 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -13,9 +13,10 @@ def get_bert_flash_attention_forward(): from xformers.ops import memory_efficient_attention as me_attention except: raise ImportError("Error: xformers module is not installed. 
Please install it to use flash attention.") + from transformers.models.bert.modeling_bert import BertAttention def forward( - self, + self: BertAttention, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -115,9 +116,11 @@ def forward( return forward -def _get_jit_fused_output_forward(): +def get_jit_fused_bert_self_output_forward(): + + from transformers.models.bert.modeling_bert import BertSelfOutput - def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) hidden_states = self.LayerNorm(hidden_states) @@ -126,9 +129,14 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return forward -def get_jit_fused_bert_self_output_forward(): - return _get_jit_fused_output_forward() +def get_jit_fused_bert_output_forward(): + from transformers.models.bert.modeling_bert import BertOutput -def get_jit_fused_bert_output_forward(): - return _get_jit_fused_output_forward() + def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + return forward diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 080ace10eb82..a586d39ec837 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -78,9 +78,10 @@ def get_bloom_flash_attention_forward(enabel_jit_fused=False): from xformers.ops import memory_efficient_attention as me_attention except: raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + from transformers.models.bloom.modeling_bloom import BloomAttention def forward( - self, + self: BloomAttention, hidden_states: torch.Tensor, residual: torch.Tensor, alibi: torch.Tensor, @@ -156,8 +157,10 @@ def forward( def get_jit_fused_bloom_attention_forward(): + from transformers.models.bloom.modeling_bloom import BloomAttention + def forward( - self, + self: BloomAttention, hidden_states: torch.Tensor, residual: torch.Tensor, alibi: torch.Tensor, @@ -252,7 +255,9 @@ def forward( def get_jit_fused_bloom_mlp_forward(): - def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + from transformers.models.bloom.modeling_bloom import BloomMLP + + def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states)) if self.pretraining_tp > 1 and self.slow_but_exact: @@ -273,9 +278,11 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch. 
def get_jit_fused_bloom_gelu_forward(): + from transformers.models.bloom.modeling_bloom import BloomGelu + from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor: bias = torch.zeros_like(x) if self.training: return JitGeLUFunction.apply(x, bias) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index c749a40f9221..ba789eb6458e 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -5,10 +5,12 @@ def get_gpt2_flash_attention_forward(): + from transformers.models.gpt2.modeling_gpt2 import GPT2Attention + from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention def forward( - self, + self: GPT2Attention, hidden_states: Optional[Tuple[torch.FloatTensor]], layer_past: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 75fe34bfeead..83aa01716832 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -5,12 +5,12 @@ def get_llama_flash_attention_forward(): - from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention def forward( - self, + self: LlamaAttention, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index d0b4552ef49a..36dd5788038b 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -6,10 +6,12 @@ def get_opt_flash_attention_forward(): + from transformers.models.opt.modeling_opt import OPTAttention + from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention def forward( - self, + self: OPTAttention, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, @@ -92,8 +94,10 @@ def forward( def get_jit_fused_opt_decoder_layer_forward(): + from transformers.models.opt.modeling_opt import OPTDecoderLayer + def forward( - self, + self: OPTDecoderLayer, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index d54037ac01ff..14befc1f54e9 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import torch @@ -11,19 +11,20 @@ def get_t5_flash_attention_forward(): from xformers.ops import memory_efficient_attention as me_attention except: raise ImportError("Error: xformers module is not installed. 
Please install it to use flash attention.") + from transformers.models.t5.modeling_t5 import T5Attention def forward( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - query_length=None, - use_cache=False, - output_attentions=False, - ): + self: T5Attention, + hidden_states: torch.Tensor, + mask: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + layer_head_mask: Optional[torch.Tensor] = None, + query_length: Optional[int] = None, + use_cache: bool = False, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ @@ -133,7 +134,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): def get_jit_fused_T5_layer_ff_forward(): - def forward(self, hidden_states): + from transformers.models.t5.modeling_t5 import T5LayerFF + + def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor: forwarded_states = self.layer_norm(hidden_states) forwarded_states = self.DenseReluDense(forwarded_states) hidden_states = self.dropout_add(forwarded_states, hidden_states, self.dropout.p, self.dropout.training) @@ -144,16 +147,18 @@ def forward(self, hidden_states): def get_T5_layer_self_attention_forward(): + from transformers.models.t5.modeling_t5 import T5LayerSelfAttention + def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - ): + self: T5LayerSelfAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( normed_hidden_states, @@ -173,18 +178,20 @@ def forward( def get_T5_layer_cross_attention_forward(): + from transformers.models.t5.modeling_t5 import T5LayerCrossAttention + def forward( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - query_length=None, - output_attentions=False, - ): + self: T5LayerCrossAttention, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + query_length: Optional[int] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( normed_hidden_states,
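
All of the `dropout_add` method replacements introduced in this series route through the two helpers in `colossalai/shardformer/modeling/jit.py`. A small sanity-check sketch of how the unfused and the JIT fused paths relate (not part of the patches; it assumes a GPU is available and uses only the imports and argument orders shown in the diffs above — with dropout disabled the two paths should agree, because the fused kernel receives a zero bias):

import torch
from transformers.models.bloom.modeling_bloom import dropout_add
from colossalai.kernel.jit import bias_dropout_add_fused_inference

x = torch.randn(2, 4, 8, device='cuda')
residual = torch.randn_like(x)
bias = torch.zeros_like(x)    # the replacement methods always pass a zero bias

# training=False makes dropout a no-op, so both paths should reduce to residual + x
ref = dropout_add(x, residual, 0.1, False)
fused = bias_dropout_add_fused_inference(x, bias, residual, 0.1)
torch.testing.assert_close(fused, ref)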