From cb9d34b7da30f2326c96eafd3df02f14545d0d59 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 16:28:28 +0800 Subject: [PATCH 1/9] [shardformer] opt support flash attention --- colossalai/shardformer/layer/__init__.py | 3 +- .../layer/flash_attention/__init__.py | 3 + .../flash_attention/flash_attention_opt.py | 87 +++++++++++ colossalai/shardformer/policies/opt.py | 146 +++++++++++------- colossalai/shardformer/shard/shard_config.py | 12 +- .../test_flash_attention_for_opt.py | 38 +++++ tests/test_shardformer/test_model/_utils.py | 5 +- .../test_model/test_shard_opt.py | 39 ++--- 8 files changed, 241 insertions(+), 92 deletions(-) create mode 100644 colossalai/shardformer/layer/flash_attention/__init__.py create mode 100644 colossalai/shardformer/layer/flash_attention/flash_attention_opt.py create mode 100644 tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7fad4948dfd0..fef65fc5fa52 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,5 +1,6 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D +from .flash_attention import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm @@ -8,5 +9,5 @@ __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm' + 'FusedLayerNorm', 'FusedRMSNorm', 'opt_flash_attention_forward' ] diff --git a/colossalai/shardformer/layer/flash_attention/__init__.py b/colossalai/shardformer/layer/flash_attention/__init__.py new file mode 100644 index 000000000000..e3ee83b005c4 --- /dev/null +++ b/colossalai/shardformer/layer/flash_attention/__init__.py @@ -0,0 +1,3 @@ +from .flash_attention_opt import opt_flash_attention_forward + +__all__ = ['opt_flash_attention_forward'] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py new file mode 100644 index 000000000000..3f0445078357 --- /dev/null +++ b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py @@ -0,0 +1,87 @@ +from typing import Optional, Tuple + +import torch + +__all__ = ['opt_flash_attention_forward'] + + +def opt_flash_attention_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if 
is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() + + try: + from xformers.ops import memory_efficient_attention as me_attention + except ImportError: + print("Error: xformers module is not installed. Please install it to use flash attention.") + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=attention_mask, + p=self.dropout, + scale=self.scaling) + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index b87db53f45f1..725def7f66ab 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,4 +1,10 @@ -from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedLayerNorm, + Linear1D_Col, + Linear1D_Row, + VocabParallelEmbedding1D, + opt_flash_attention_forward, +) from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -29,67 +35,90 @@ def preprocess(self): def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer - policy = {} - - if self.shard_config.enable_tensor_parallelism: - policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]) - policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]) - - policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]) + base_policy = { + OPTDecoder: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]), + OPTDecoderLayer: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]), + OPTAttention: + ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]), + } # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), - policy=policy, - target_key=OPTDecoder) - self.append_or_create_submodule_replacement(description=[ + 
base_policy[OPTDecoder].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="final_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True)) + base_policy[OPTDecoderLayer].sub_module_replacement.extend([ SubModuleReplacementDescription(suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True) - ], - policy=policy, - target_key=OPTDecoderLayer) + ]) - return policy + # use flash attention + if self.shard_config.enable_flash_attention: + del base_policy[OPTAttention] + new_item = { + OPTAttention: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + method_replacement={ + 'forward': opt_flash_attention_forward, + }, + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="q_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="k_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="v_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + ]), + } + base_policy.update(new_item) + + return base_policy def postprocess(self): return self.model @@ -107,12 +136,15 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() + new_item = { + OPTForCausalLM: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } - if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=OPTForCausalLM) + policy.update(new_item) return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 83c08d275df3..0e826d38ddd0 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,14 +13,14 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. - enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. + enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. enable_all_optimization (bool): Whether to turn on all optimization, default is False. 
""" tensor_parallel_process_group: ProcessGroup = None - enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False + enable_flash_attention: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int @@ -34,11 +34,8 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): - if not self.enable_tensor_parallelism: - self._tensor_parallel_size = 1 - else: - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) # turn on all optimization if all_optimization is set to True if self.enable_all_optimization: @@ -50,3 +47,4 @@ def _turn_on_all_optimization(self): """ # you can add all the optimization flag here self.enable_fused_normalization = True + self.enable_flash_attention = True diff --git a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py new file mode 100644 index 000000000000..7f196e86a253 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py @@ -0,0 +1,38 @@ +from copy import deepcopy + +import torch +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.opt.modeling_opt import OPTAttention + +from colossalai.shardformer.layer import opt_flash_attention_forward + + +def test_flash_attention_for_opt(): + BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 + + # generate input + hidden_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), + dtype=torch.float32, + device="cuda", + requires_grad=True) + key_value_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), + dtype=torch.float32, + device="cuda", + requires_grad=True) + attention_mask = torch.ones((BATCH, 1, N_CTX, N_CTX), dtype=torch.float32, device="cuda") + + opt_attention = OPTAttention(embed_dim=D_HEAD * N_HEADS, num_heads=N_HEADS, dropout=0, is_decoder=True, + bias=True).to("cuda") + + opt_flash_attention = deepcopy(opt_attention) + setattr(opt_flash_attention, 'forward', + opt_flash_attention_forward.__get__(opt_flash_attention, opt_flash_attention.__class__)) + opt_attention_output = opt_attention(hidden_states, key_value_states, attention_mask=attention_mask) + flash_attention_output = opt_flash_attention(hidden_states, key_value_states, attention_mask=attention_mask) + + assert_close(flash_attention_output[0], opt_attention_output[0], atol=1e-5, rtol=1e-5) + + +if __name__ == '__main__': + test_flash_attention_for_opt() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d83d9ecd39e0..7eff506e0c28 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,13 +3,12 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True): +def build_model(model_fn, enable_flash_attention=False): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism) + shard_config = ShardConfig(enable_fused_normalization=True, enable_flash_attention=enable_flash_attention) model_copy = copy.deepcopy(org_model) 
shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 25bccb13b1a8..af25ba6a8c6c 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,9 +6,9 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, + check_state_dict_equal, clear_cache_before_run, parameterize, rerun_if_address_is_in_use, @@ -43,53 +43,44 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad - shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # check embedding grad org_grad = opt_model.decoder.embed_tokens.weight.grad shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad - shard_weight = shard_opt_model.decoder.embed_tokens.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +def check_OPTModel(enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() -def check_OPTModel(rank, world_size, port): +def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', 
port=port, backend='nccl') - run_t5_test() + check_OPTModel() + torch.cuda.empty_cache() @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_OPTModel(): - spawn(check_OPTModel, 4) + spawn(run_dist, 4) if __name__ == '__main__': From ac179b5c7ea8d11b77711c1eae1b916a72f656fc Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 17:13:19 +0800 Subject: [PATCH 2/9] [shardformer] opt support flash attention --- colossalai/shardformer/policies/opt.py | 161 +++++++++--------- colossalai/shardformer/shard/shard_config.py | 14 +- tests/test_shardformer/test_model/_utils.py | 9 +- .../test_model/test_shard_opt.py | 38 +++-- 4 files changed, 120 insertions(+), 102 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 725def7f66ab..5e1c1cf36ed4 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -35,90 +35,94 @@ def preprocess(self): def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer - base_policy = { - OPTDecoder: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]), - OPTDecoderLayer: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]), - OPTAttention: - ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]), - } + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]) + policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]) + + policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[OPTDecoder].sub_module_replacement.append( - 
SubModuleReplacementDescription(suffix="final_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True)) - base_policy[OPTDecoderLayer].sub_module_replacement.extend([ + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription(suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True) - ]) + ], + policy=policy, + target_key=OPTDecoderLayer) # use flash attention if self.shard_config.enable_flash_attention: - del base_policy[OPTAttention] - new_item = { - OPTAttention: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - method_replacement={ - 'forward': opt_flash_attention_forward, - }, - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="q_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="k_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="v_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - ]), - } - base_policy.update(new_item) - - return base_policy + policy[OPTAttention] = ModulePolicyDescription( + method_replacement={ + 'forward': opt_flash_attention_forward, + }, + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="q_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="k_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="v_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="out_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + ]) + + return policy def postprocess(self): return self.model @@ -136,15 +140,12 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() - new_item = { - OPTForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=OPTForCausalLM) - policy.update(new_item) return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0e826d38ddd0..792e9a3f7ad4 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,11 +13,12 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global 
process group. + enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. - enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. enable_all_optimization (bool): Whether to turn on all optimization, default is False. """ tensor_parallel_process_group: ProcessGroup = None + enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False enable_flash_attention: bool = False @@ -34,12 +35,11 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) - - # turn on all optimization if all_optimization is set to True - if self.enable_all_optimization: - self._turn_on_all_optimization() + if not self.enable_tensor_parallelism: + self._tensor_parallel_size = 1 + else: + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) def _turn_on_all_optimization(self): """ diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 7eff506e0c28..44a8f841a6d4 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,12 +3,17 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_flash_attention=False): +def build_model(model_fn, + enable_fused_normalization=True, + enable_tensor_parallelism=True, + enable_flash_attention=False): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(enable_fused_normalization=True, enable_flash_attention=enable_flash_attention) + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index af25ba6a8c6c..9ca0669340a5 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,6 +6,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, check_state_dict_equal, @@ -43,44 +44,55 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad + shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = 
torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # check embedding grad org_grad = opt_model.decoder.embed_tokens.weight.grad shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + shard_weight = shard_opt_model.decoder.embed_tokens.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) @parameterize('enable_flash_attention', [True, False]) -def check_OPTModel(enable_flash_attention): +def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_flash_attention) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() -def run_dist(rank, world_size, port): +def check_OPTModel(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_OPTModel() - torch.cuda.empty_cache() + run_opt_test() @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_OPTModel(): - spawn(run_dist, 4) + spawn(check_OPTModel, 4) if __name__ == '__main__': From 00c4a826e2fe0fcc6e652b87fdb80e60e83fdbfe Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 17:21:08 +0800 Subject: [PATCH 3/9] [shardformer] opt support flash attention --- colossalai/shardformer/policies/opt.py | 6 ------ colossalai/shardformer/shard/shard_config.py | 3 +++ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 5e1c1cf36ed4..7a838a632776 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -79,12 +79,6 @@ def module_policy(self): ]) # optimization configuration - if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), - policy=policy, - target_key=OPTDecoder) - # optimization configuration if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), diff --git 
a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 792e9a3f7ad4..4076becdbbc9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -40,6 +40,9 @@ def __post_init__(self): else: # get the parallel size self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + # turn on all optimization if all_optimization is set to True + if self.enable_all_optimization: + self._turn_on_all_optimization() def _turn_on_all_optimization(self): """ From 27526d943654b10f75304d3f08e1f02dc8134dfc Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 18:22:23 +0800 Subject: [PATCH 4/9] [shardformer] opt support flash attention --- colossalai/shardformer/layer/__init__.py | 2 +- .../layer/flash_attention/__init__.py | 3 --- .../flash_attention/flash_attention_opt.py | 5 +++-- colossalai/shardformer/policies/opt.py | 21 +++---------------- 4 files changed, 7 insertions(+), 24 deletions(-) delete mode 100644 colossalai/shardformer/layer/flash_attention/__init__.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index fef65fc5fa52..a994e8f37e91 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,6 +1,6 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D -from .flash_attention import opt_flash_attention_forward +from .flash_attention.flash_attention_opt import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm diff --git a/colossalai/shardformer/layer/flash_attention/__init__.py b/colossalai/shardformer/layer/flash_attention/__init__.py deleted file mode 100644 index e3ee83b005c4..000000000000 --- a/colossalai/shardformer/layer/flash_attention/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .flash_attention_opt import opt_flash_attention_forward - -__all__ = ['opt_flash_attention_forward'] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py index 3f0445078357..8b084fcb72ae 100644 --- a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py +++ b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py @@ -68,8 +68,9 @@ def opt_flash_attention_forward( try: from xformers.ops import memory_efficient_attention as me_attention - except ImportError: - print("Error: xformers module is not installed. Please install it to use flash attention.") + except: + raise ImportError("Error: xformers module is not installed. 
Please install it to use flash attention.") + attn_output = me_attention(query_states, key_states, value_states, diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 7a838a632776..9ca0373c69da 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -97,24 +97,9 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[OPTAttention] = ModulePolicyDescription( - method_replacement={ - 'forward': opt_flash_attention_forward, - }, - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="q_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="k_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="v_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="out_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - ]) + policy[OPTAttention] = ModulePolicyDescription(method_replacement={ + 'forward': opt_flash_attention_forward, + }) return policy From 70535bd398f25660f60a6912856151172fa128d5 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 5 Jul 2023 11:16:46 +0800 Subject: [PATCH 5/9] [shardformer] move to modeling --- colossalai/shardformer/layer/__init__.py | 15 +++- .../flash_attention/flash_attention_opt.py | 88 ------------------- colossalai/shardformer/policies/opt.py | 9 +- .../test_flash_attention_for_opt.py | 38 -------- 4 files changed, 13 insertions(+), 137 deletions(-) delete mode 100644 colossalai/shardformer/layer/flash_attention/flash_attention_opt.py delete mode 100644 tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index a994e8f37e91..72e002497fa5 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,13 +1,20 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D -from .flash_attention.flash_attention_opt import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ - "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', - 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm', 'opt_flash_attention_forward' + "Embedding1D", + "VocabParallelEmbedding1D", + "Linear1D_Col", + "Linear1D_Row", + 'GPT2FusedLinearConv1D_Col', + 'GPT2FusedLinearConv1D_Row', + 'DropoutForParallelInput', + 'DropoutForReplicatedInput', + "cross_entropy_1d", + 'FusedLayerNorm', + 'FusedRMSNorm', ] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py deleted file mode 100644 index 8b084fcb72ae..000000000000 --- a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Optional, Tuple - -import torch - -__all__ = ['opt_flash_attention_forward'] - - -def opt_flash_attention_forward( - 
self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - - attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) - # get query proj - # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) - query_states = self.q_proj(hidden_states).view(*attention_input_shape) - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) - value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*attention_input_shape) - value_states = self.v_proj(key_value_states).view(*attention_input_shape) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) - key_states = torch.cat([past_key_value[0], key_states], dim=1) - value_states = torch.cat([past_key_value[1], value_states], dim=1) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - src_len = key_states.size(1) - if layer_head_mask != None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}") - if attention_mask != None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") - attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() - - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. 
Please install it to use flash attention.") - - attn_output = me_attention(query_states, - key_states, - value_states, - attn_bias=attention_mask, - p=self.dropout, - scale=self.scaling) - - attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 9ca0373c69da..fa7bf1befd44 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,12 +1,7 @@ -from colossalai.shardformer.layer import ( - FusedLayerNorm, - Linear1D_Col, - Linear1D_Row, - VocabParallelEmbedding1D, - opt_flash_attention_forward, -) +from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .._utils import getattr_, setattr_ +from ..modeling.opt import opt_flash_attention_forward from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ diff --git a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py deleted file mode 100644 index 7f196e86a253..000000000000 --- a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py +++ /dev/null @@ -1,38 +0,0 @@ -from copy import deepcopy - -import torch -import torch.nn as nn -from torch.testing import assert_close -from transformers.models.opt.modeling_opt import OPTAttention - -from colossalai.shardformer.layer import opt_flash_attention_forward - - -def test_flash_attention_for_opt(): - BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 - - # generate input - hidden_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), - dtype=torch.float32, - device="cuda", - requires_grad=True) - key_value_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), - dtype=torch.float32, - device="cuda", - requires_grad=True) - attention_mask = torch.ones((BATCH, 1, N_CTX, N_CTX), dtype=torch.float32, device="cuda") - - opt_attention = OPTAttention(embed_dim=D_HEAD * N_HEADS, num_heads=N_HEADS, dropout=0, is_decoder=True, - bias=True).to("cuda") - - opt_flash_attention = deepcopy(opt_attention) - setattr(opt_flash_attention, 'forward', - opt_flash_attention_forward.__get__(opt_flash_attention, opt_flash_attention.__class__)) - opt_attention_output = opt_attention(hidden_states, key_value_states, attention_mask=attention_mask) - flash_attention_output = opt_flash_attention(hidden_states, key_value_states, attention_mask=attention_mask) - - assert_close(flash_attention_output[0], opt_attention_output[0], atol=1e-5, rtol=1e-5) - - -if __name__ == '__main__': - test_flash_attention_for_opt() From 794dd86a6e63e97df35cad6c746c08caace758e6 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 5 Jul 2023 11:19:51 +0800 Subject: [PATCH 6/9] [shardformer] move to modeling --- colossalai/shardformer/modeling/opt.py | 88 ++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 colossalai/shardformer/modeling/opt.py diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py new file mode 100644 index 
000000000000..8b084fcb72ae --- /dev/null +++ b/colossalai/shardformer/modeling/opt.py @@ -0,0 +1,88 @@ +from typing import Optional, Tuple + +import torch + +__all__ = ['opt_flash_attention_forward'] + + +def opt_flash_attention_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. 
Please install it to use flash attention.") + + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=attention_mask, + p=self.dropout, + scale=self.scaling) + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value From a8b9abed244a44b1130d5ebad55aedb48bebda92 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Fri, 7 Jul 2023 13:51:19 +0800 Subject: [PATCH 7/9] [shardformer] gpt2 support flash attention --- colossalai/shardformer/modeling/gpt2.py | 89 +++++++++++++++++++ colossalai/shardformer/policies/gpt2.py | 8 +- tests/kit/model_zoo/transformers/gpt.py | 6 +- .../test_model/test_shard_gpt2.py | 5 +- 4 files changed, 102 insertions(+), 6 deletions(-) create mode 100644 colossalai/shardformer/modeling/gpt2.py diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py new file mode 100644 index 000000000000..8a8f983441ee --- /dev/null +++ b/colossalai/shardformer/modeling/gpt2.py @@ -0,0 +1,89 @@ +from typing import Optional, Tuple, Union + +import torch + +__all__ = ['get_gpt2_forward'] + +def get_gpt2_forward(): + + try: + from xformers.ops import memory_efficient_attention as me_attention + from xformers.ops.fmha.attn_bias import LowerTriangularMask + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + + def gpt2_flash_attention_forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + _, tgt_len, _ = hidden_states.size() + assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
+ ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = split_heads(query, self.num_heads, self.head_dim) + key = split_heads(key, self.num_heads, self.head_dim) + value = split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + if use_cache is True: + present = (key, value) + else: + present = None + + attn_bias = None + if not self.is_cross_attention: + attn_bias = LowerTriangularMask() + if attention_mask != None: + if attn_bias: + attn_bias.add_bias(attention_mask) + else: + batch_size, _, tgt_len, src_len = attention_mask.size() + attn_bias = attention_mask.expand(batch_size, self.num_heads, tgt_len, src_len).contiguous() + + attn_output = me_attention(query=query, key=key, value=value, attn_bias=attn_bias, p=self.attn_dropout.p, scale=value.size(-1) ** -0.5) + attn_output = merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + outputs = (attn_output, present, None) + + return outputs + + return gpt2_flash_attention_forward + +def split_heads(tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor + +def merge_heads(tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) \ No newline at end of file diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 549cdbf87a80..cd11bdb39f94 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -4,6 +4,7 @@ from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from ..modeling.gpt2 import get_gpt2_forward __all__ = [ 'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy', @@ -29,7 +30,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model + from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model, GPT2Attention policy = {} @@ -106,6 +107,11 @@ def module_policy(self): ], policy=policy, target_key=GPT2Block) + + if self.shard_config.enable_flash_attention: + policy[GPT2Attention] = ModulePolicyDescription(method_replacement={ + 'forward': get_gpt2_forward(), + }) return policy def postprocess(self): diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index b9e0310780af..9c4e23b52e31 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -16,8 +16,8 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt') # input_ids = tokenized_input['input_ids'] # attention_mask = tokenized_input['attention_mask'] - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + input_ids = 
torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -33,7 +33,7 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) + data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) return data diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index ee7737687d99..487f702333d2 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -69,10 +69,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From ed5f235decd918e75a41b72c8236e2d1fdc14f9c Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Fri, 7 Jul 2023 13:56:25 +0800 Subject: [PATCH 8/9] [shardformer] gpt2 support flash attention --- colossalai/shardformer/modeling/gpt2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 8a8f983441ee..3a55431d0453 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -62,8 +62,8 @@ def gpt2_flash_attention_forward( else: batch_size, _, tgt_len, src_len = attention_mask.size() attn_bias = attention_mask.expand(batch_size, self.num_heads, tgt_len, src_len).contiguous() - - attn_output = me_attention(query=query, key=key, value=value, attn_bias=attn_bias, p=self.attn_dropout.p, scale=value.size(-1) ** -0.5) + scale = value.size(-1) ** -0.5 + attn_output = me_attention(query=query, key=key, value=value, attn_bias=attn_bias, p=self.attn_dropout.p, scale=scale) attn_output = merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) From af98e60a8c5b061e3aa87213a5d777678bd3433d Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Fri, 7 Jul 2023 14:04:01 +0800 Subject: [PATCH 9/9] [shardformer] gpt2 support flash attention --- colossalai/shardformer/modeling/gpt2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 3a55431d0453..a9a38bce235f 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -62,8 +62,12 @@ def 
gpt2_flash_attention_forward( else: batch_size, _, tgt_len, src_len = attention_mask.size() attn_bias = attention_mask.expand(batch_size, self.num_heads, tgt_len, src_len).contiguous() + scale = value.size(-1) ** -0.5 + if self.scale_attn_by_inverse_layer_idx: + scale = scale * (1 / float(self.layer_idx + 1)) attn_output = me_attention(query=query, key=key, value=value, attn_bias=attn_bias, p=self.attn_dropout.p, scale=scale) + attn_output = merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output)
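
Usage sketch (illustrative, not part of the patches above): the `enable_flash_attention` flag that this series adds to `ShardConfig` is used the same way as in `tests/test_shardformer/test_model/_utils.py`. In the sketch below, only the `ShardConfig`/`ShardFormer` calls come from the series; the choice of `OPTModel`/`OPTConfig` and the note about the distributed launch are assumptions made for the example.

    import copy

    from transformers import OPTConfig, OPTModel

    from colossalai.shardformer import ShardConfig, ShardFormer

    # NOTE (assumption): as in the tests above, colossalai.launch(...) is expected to
    # have initialized the process group before ShardFormer is used; omitted here.

    # an ordinary Hugging Face OPT model (illustrative choice)
    org_model = OPTModel(OPTConfig()).cuda()

    # turn on the new flag; enable_all_optimization=True would switch it on implicitly
    # together with fused layernorm (see _turn_on_all_optimization in shard_config.py)
    shard_config = ShardConfig(enable_fused_normalization=True,
                               enable_tensor_parallelism=True,
                               enable_flash_attention=True)

    # the OPT/GPT2 policies then replace the attention forward with the xformers-based
    # opt_flash_attention_forward / gpt2_flash_attention_forward
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model = shard_former.optimize(copy.deepcopy(org_model)).cuda()

The sharded model is then used like the original one, as the tests above do when comparing forward and backward results against the unmodified model.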