From cb9d34b7da30f2326c96eafd3df02f14545d0d59 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 16:28:28 +0800 Subject: [PATCH 01/11] [shardformer] opt support flash attention --- colossalai/shardformer/layer/__init__.py | 3 +- .../layer/flash_attention/__init__.py | 3 + .../flash_attention/flash_attention_opt.py | 87 +++++++++++ colossalai/shardformer/policies/opt.py | 146 +++++++++++------- colossalai/shardformer/shard/shard_config.py | 12 +- .../test_flash_attention_for_opt.py | 38 +++++ tests/test_shardformer/test_model/_utils.py | 5 +- .../test_model/test_shard_opt.py | 39 ++--- 8 files changed, 241 insertions(+), 92 deletions(-) create mode 100644 colossalai/shardformer/layer/flash_attention/__init__.py create mode 100644 colossalai/shardformer/layer/flash_attention/flash_attention_opt.py create mode 100644 tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7fad4948dfd0..fef65fc5fa52 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,5 +1,6 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D +from .flash_attention import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm @@ -8,5 +9,5 @@ __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm' + 'FusedLayerNorm', 'FusedRMSNorm', 'opt_flash_attention_forward' ] diff --git a/colossalai/shardformer/layer/flash_attention/__init__.py b/colossalai/shardformer/layer/flash_attention/__init__.py new file mode 100644 index 000000000000..e3ee83b005c4 --- /dev/null +++ b/colossalai/shardformer/layer/flash_attention/__init__.py @@ -0,0 +1,3 @@ +from .flash_attention_opt import opt_flash_attention_forward + +__all__ = ['opt_flash_attention_forward'] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py new file mode 100644 index 000000000000..3f0445078357 --- /dev/null +++ b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py @@ -0,0 +1,87 @@ +from typing import Optional, Tuple + +import torch + +__all__ = ['opt_flash_attention_forward'] + + +def opt_flash_attention_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() + + try: + from xformers.ops import memory_efficient_attention as me_attention + except ImportError: + print("Error: xformers module is not installed. Please install it to use flash attention.") + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=attention_mask, + p=self.dropout, + scale=self.scaling) + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index b87db53f45f1..725def7f66ab 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,4 +1,10 @@ -from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedLayerNorm, + Linear1D_Col, + Linear1D_Row, + VocabParallelEmbedding1D, + opt_flash_attention_forward, +) from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -29,67 +35,90 @@ def preprocess(self): def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer - policy = {} - - if self.shard_config.enable_tensor_parallelism: - policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]) - policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]) - - policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]) + base_policy = { + OPTDecoder: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]), + OPTDecoderLayer: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]), + OPTAttention: + ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]), + } # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), - policy=policy, - target_key=OPTDecoder) - self.append_or_create_submodule_replacement(description=[ + base_policy[OPTDecoder].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="final_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True)) + base_policy[OPTDecoderLayer].sub_module_replacement.extend([ SubModuleReplacementDescription(suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True) - ], - policy=policy, - target_key=OPTDecoderLayer) + ]) - return policy + # use flash attention + if self.shard_config.enable_flash_attention: + del base_policy[OPTAttention] + new_item = { + OPTAttention: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + method_replacement={ + 'forward': opt_flash_attention_forward, + }, + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="q_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="k_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="v_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + ]), + } + base_policy.update(new_item) + + return base_policy def postprocess(self): return self.model @@ -107,12 +136,15 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() + new_item = { + OPTForCausalLM: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } - if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=OPTForCausalLM) + policy.update(new_item) return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 83c08d275df3..0e826d38ddd0 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,14 +13,14 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. - enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. + enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. enable_all_optimization (bool): Whether to turn on all optimization, default is False. """ tensor_parallel_process_group: ProcessGroup = None - enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False + enable_flash_attention: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int @@ -34,11 +34,8 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): - if not self.enable_tensor_parallelism: - self._tensor_parallel_size = 1 - else: - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) # turn on all optimization if all_optimization is set to True if self.enable_all_optimization: @@ -50,3 +47,4 @@ def _turn_on_all_optimization(self): """ # you can add all the optimization flag here self.enable_fused_normalization = True + self.enable_flash_attention = True diff --git a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py new file mode 100644 index 000000000000..7f196e86a253 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py @@ -0,0 +1,38 @@ +from copy import deepcopy + +import torch +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.opt.modeling_opt import OPTAttention + +from colossalai.shardformer.layer import opt_flash_attention_forward + + +def test_flash_attention_for_opt(): + BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 + + # generate input + hidden_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), + dtype=torch.float32, + device="cuda", + requires_grad=True) + key_value_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), + dtype=torch.float32, + device="cuda", + requires_grad=True) + attention_mask = torch.ones((BATCH, 1, N_CTX, N_CTX), dtype=torch.float32, device="cuda") + + opt_attention = OPTAttention(embed_dim=D_HEAD * N_HEADS, num_heads=N_HEADS, dropout=0, is_decoder=True, + bias=True).to("cuda") + + opt_flash_attention = deepcopy(opt_attention) + setattr(opt_flash_attention, 'forward', + opt_flash_attention_forward.__get__(opt_flash_attention, opt_flash_attention.__class__)) + opt_attention_output = opt_attention(hidden_states, key_value_states, attention_mask=attention_mask) + flash_attention_output = opt_flash_attention(hidden_states, key_value_states, attention_mask=attention_mask) + + assert_close(flash_attention_output[0], opt_attention_output[0], atol=1e-5, rtol=1e-5) + + +if __name__ == '__main__': + test_flash_attention_for_opt() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d83d9ecd39e0..7eff506e0c28 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,13 +3,12 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True): +def build_model(model_fn, enable_flash_attention=False): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism) + shard_config = ShardConfig(enable_fused_normalization=True, enable_flash_attention=enable_flash_attention) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 25bccb13b1a8..af25ba6a8c6c 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,9 +6,9 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, + check_state_dict_equal, clear_cache_before_run, parameterize, rerun_if_address_is_in_use, @@ -43,53 +43,44 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad - shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # check embedding grad org_grad = opt_model.decoder.embed_tokens.weight.grad shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad - shard_weight = shard_opt_model.decoder.embed_tokens.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +def check_OPTModel(enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() -def check_OPTModel(rank, world_size, port): +def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_t5_test() + check_OPTModel() + torch.cuda.empty_cache() @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_OPTModel(): - spawn(check_OPTModel, 4) + spawn(run_dist, 4) if __name__ == '__main__': From ac179b5c7ea8d11b77711c1eae1b916a72f656fc Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 17:13:19 +0800 Subject: [PATCH 02/11] [shardformer] opt support flash attention --- colossalai/shardformer/policies/opt.py | 161 +++++++++--------- colossalai/shardformer/shard/shard_config.py | 14 +- tests/test_shardformer/test_model/_utils.py | 9 +- .../test_model/test_shard_opt.py | 38 +++-- 4 files changed, 120 insertions(+), 102 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 725def7f66ab..5e1c1cf36ed4 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -35,90 +35,94 @@ def preprocess(self): def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer - base_policy = { - OPTDecoder: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]), - OPTDecoderLayer: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]), - OPTAttention: - ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]), - } + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]) + policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]) + + policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[OPTDecoder].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="final_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True)) - base_policy[OPTDecoderLayer].sub_module_replacement.extend([ + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription(suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True) - ]) + ], + policy=policy, + target_key=OPTDecoderLayer) # use flash attention if self.shard_config.enable_flash_attention: - del base_policy[OPTAttention] - new_item = { - OPTAttention: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - method_replacement={ - 'forward': opt_flash_attention_forward, - }, - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="q_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="k_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="v_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - ]), - } - base_policy.update(new_item) - - return base_policy + policy[OPTAttention] = ModulePolicyDescription( + method_replacement={ + 'forward': opt_flash_attention_forward, + }, + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="q_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="k_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="v_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="out_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + ]) + + return policy def postprocess(self): return self.model @@ -136,15 +140,12 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() - new_item = { - OPTForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=OPTForCausalLM) - policy.update(new_item) return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0e826d38ddd0..792e9a3f7ad4 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,11 +13,12 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. + enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. - enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. enable_all_optimization (bool): Whether to turn on all optimization, default is False. """ tensor_parallel_process_group: ProcessGroup = None + enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False enable_flash_attention: bool = False @@ -34,12 +35,11 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) - - # turn on all optimization if all_optimization is set to True - if self.enable_all_optimization: - self._turn_on_all_optimization() + if not self.enable_tensor_parallelism: + self._tensor_parallel_size = 1 + else: + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) def _turn_on_all_optimization(self): """ diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 7eff506e0c28..44a8f841a6d4 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,12 +3,17 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_flash_attention=False): +def build_model(model_fn, + enable_fused_normalization=True, + enable_tensor_parallelism=True, + enable_flash_attention=False): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(enable_fused_normalization=True, enable_flash_attention=enable_flash_attention) + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index af25ba6a8c6c..9ca0669340a5 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,6 +6,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, check_state_dict_equal, @@ -43,44 +44,55 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad + shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # check embedding grad org_grad = opt_model.decoder.embed_tokens.weight.grad shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + shard_weight = shard_opt_model.decoder.embed_tokens.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) @parameterize('enable_flash_attention', [True, False]) -def check_OPTModel(enable_flash_attention): +def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_flash_attention) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() -def run_dist(rank, world_size, port): +def check_OPTModel(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_OPTModel() - torch.cuda.empty_cache() + run_opt_test() @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_OPTModel(): - spawn(run_dist, 4) + spawn(check_OPTModel, 4) if __name__ == '__main__': From 00c4a826e2fe0fcc6e652b87fdb80e60e83fdbfe Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 17:21:08 +0800 Subject: [PATCH 03/11] [shardformer] opt support flash attention --- colossalai/shardformer/policies/opt.py | 6 ------ colossalai/shardformer/shard/shard_config.py | 3 +++ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 5e1c1cf36ed4..7a838a632776 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -79,12 +79,6 @@ def module_policy(self): ]) # optimization configuration - if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), - policy=policy, - target_key=OPTDecoder) - # optimization configuration if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 792e9a3f7ad4..4076becdbbc9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -40,6 +40,9 @@ def __post_init__(self): else: # get the parallel size self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + # turn on all optimization if all_optimization is set to True + if self.enable_all_optimization: + self._turn_on_all_optimization() def _turn_on_all_optimization(self): """ From 27526d943654b10f75304d3f08e1f02dc8134dfc Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 18:22:23 +0800 Subject: [PATCH 04/11] [shardformer] opt support flash attention --- colossalai/shardformer/layer/__init__.py | 2 +- .../layer/flash_attention/__init__.py | 3 --- .../flash_attention/flash_attention_opt.py | 5 +++-- colossalai/shardformer/policies/opt.py | 21 +++---------------- 4 files changed, 7 insertions(+), 24 deletions(-) delete mode 100644 colossalai/shardformer/layer/flash_attention/__init__.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index fef65fc5fa52..a994e8f37e91 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,6 +1,6 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D -from .flash_attention import opt_flash_attention_forward +from .flash_attention.flash_attention_opt import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm diff --git a/colossalai/shardformer/layer/flash_attention/__init__.py b/colossalai/shardformer/layer/flash_attention/__init__.py deleted file mode 100644 index e3ee83b005c4..000000000000 --- a/colossalai/shardformer/layer/flash_attention/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .flash_attention_opt import opt_flash_attention_forward - -__all__ = ['opt_flash_attention_forward'] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py index 3f0445078357..8b084fcb72ae 100644 --- a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py +++ b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py @@ -68,8 +68,9 @@ def opt_flash_attention_forward( try: from xformers.ops import memory_efficient_attention as me_attention - except ImportError: - print("Error: xformers module is not installed. Please install it to use flash attention.") + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + attn_output = me_attention(query_states, key_states, value_states, diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 7a838a632776..9ca0373c69da 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -97,24 +97,9 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[OPTAttention] = ModulePolicyDescription( - method_replacement={ - 'forward': opt_flash_attention_forward, - }, - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="q_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="k_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="v_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="out_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - ]) + policy[OPTAttention] = ModulePolicyDescription(method_replacement={ + 'forward': opt_flash_attention_forward, + }) return policy From 70535bd398f25660f60a6912856151172fa128d5 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 5 Jul 2023 11:16:46 +0800 Subject: [PATCH 05/11] [shardformer] move to modeling --- colossalai/shardformer/layer/__init__.py | 15 +++- .../flash_attention/flash_attention_opt.py | 88 ------------------- colossalai/shardformer/policies/opt.py | 9 +- .../test_flash_attention_for_opt.py | 38 -------- 4 files changed, 13 insertions(+), 137 deletions(-) delete mode 100644 colossalai/shardformer/layer/flash_attention/flash_attention_opt.py delete mode 100644 tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index a994e8f37e91..72e002497fa5 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,13 +1,20 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D -from .flash_attention.flash_attention_opt import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ - "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', - 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm', 'opt_flash_attention_forward' + "Embedding1D", + "VocabParallelEmbedding1D", + "Linear1D_Col", + "Linear1D_Row", + 'GPT2FusedLinearConv1D_Col', + 'GPT2FusedLinearConv1D_Row', + 'DropoutForParallelInput', + 'DropoutForReplicatedInput', + "cross_entropy_1d", + 'FusedLayerNorm', + 'FusedRMSNorm', ] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py deleted file mode 100644 index 8b084fcb72ae..000000000000 --- a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Optional, Tuple - -import torch - -__all__ = ['opt_flash_attention_forward'] - - -def opt_flash_attention_forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - - attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) - # get query proj - # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) - query_states = self.q_proj(hidden_states).view(*attention_input_shape) - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) - value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*attention_input_shape) - value_states = self.v_proj(key_value_states).view(*attention_input_shape) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) - key_states = torch.cat([past_key_value[0], key_states], dim=1) - value_states = torch.cat([past_key_value[1], value_states], dim=1) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - src_len = key_states.size(1) - if layer_head_mask != None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}") - if attention_mask != None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") - attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() - - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - - attn_output = me_attention(query_states, - key_states, - value_states, - attn_bias=attention_mask, - p=self.dropout, - scale=self.scaling) - - attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 9ca0373c69da..fa7bf1befd44 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,12 +1,7 @@ -from colossalai.shardformer.layer import ( - FusedLayerNorm, - Linear1D_Col, - Linear1D_Row, - VocabParallelEmbedding1D, - opt_flash_attention_forward, -) +from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .._utils import getattr_, setattr_ +from ..modeling.opt import opt_flash_attention_forward from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ diff --git a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py deleted file mode 100644 index 7f196e86a253..000000000000 --- a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py +++ /dev/null @@ -1,38 +0,0 @@ -from copy import deepcopy - -import torch -import torch.nn as nn -from torch.testing import assert_close -from transformers.models.opt.modeling_opt import OPTAttention - -from colossalai.shardformer.layer import opt_flash_attention_forward - - -def test_flash_attention_for_opt(): - BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 - - # generate input - hidden_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), - dtype=torch.float32, - device="cuda", - requires_grad=True) - key_value_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), - dtype=torch.float32, - device="cuda", - requires_grad=True) - attention_mask = torch.ones((BATCH, 1, N_CTX, N_CTX), dtype=torch.float32, device="cuda") - - opt_attention = OPTAttention(embed_dim=D_HEAD * N_HEADS, num_heads=N_HEADS, dropout=0, is_decoder=True, - bias=True).to("cuda") - - opt_flash_attention = deepcopy(opt_attention) - setattr(opt_flash_attention, 'forward', - opt_flash_attention_forward.__get__(opt_flash_attention, opt_flash_attention.__class__)) - opt_attention_output = opt_attention(hidden_states, key_value_states, attention_mask=attention_mask) - flash_attention_output = opt_flash_attention(hidden_states, key_value_states, attention_mask=attention_mask) - - assert_close(flash_attention_output[0], opt_attention_output[0], atol=1e-5, rtol=1e-5) - - -if __name__ == '__main__': - test_flash_attention_for_opt() From 794dd86a6e63e97df35cad6c746c08caace758e6 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 5 Jul 2023 11:19:51 +0800 Subject: [PATCH 06/11] [shardformer] move to modeling --- colossalai/shardformer/modeling/opt.py | 88 ++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 colossalai/shardformer/modeling/opt.py diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py new file mode 100644 index 000000000000..8b084fcb72ae --- /dev/null +++ b/colossalai/shardformer/modeling/opt.py @@ -0,0 +1,88 @@ +from typing import Optional, Tuple + +import torch + +__all__ = ['opt_flash_attention_forward'] + + +def opt_flash_attention_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=attention_mask, + p=self.dropout, + scale=self.scaling) + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value From 366e90bcfc68d8502b634f076cd32792f57d8c10 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 13 Jul 2023 21:00:48 +0800 Subject: [PATCH 07/11] [shardformer] update gpt2 to use coloattention --- colossalai/shardformer/modeling/gpt2.py | 54 +++++++++++++++---------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index a9a38bce235f..24a19d52ee5e 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -4,14 +4,16 @@ __all__ = ['get_gpt2_forward'] + def get_gpt2_forward(): - try: - from xformers.ops import memory_efficient_attention as me_attention - from xformers.ops.fmha.attn_bias import LowerTriangularMask - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - + # try: + # from xformers.ops import memory_efficient_attention as me_attention + # from xformers.ops.fmha.attn_bias import LowerTriangularMask + # except: + # raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + def gpt2_flash_attention_forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], @@ -30,8 +32,7 @@ def gpt2_flash_attention_forward( if not hasattr(self, "q_attn"): raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." - ) + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.") query = self.q_attn(hidden_states) key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) @@ -52,31 +53,39 @@ def gpt2_flash_attention_forward( present = (key, value) else: present = None - - attn_bias = None + if not self.is_cross_attention: - attn_bias = LowerTriangularMask() + attn_mask_type = AttnMaskType.causal + flash_attention_mask = None if attention_mask != None: - if attn_bias: - attn_bias.add_bias(attention_mask) + if attn_mask_type == AttnMaskType.causal: + attn_mask_type == AttnMaskType.paddedcausal else: - batch_size, _, tgt_len, src_len = attention_mask.size() - attn_bias = attention_mask.expand(batch_size, self.num_heads, tgt_len, src_len).contiguous() - - scale = value.size(-1) ** -0.5 + attn_mask_type = AttnMaskType.padding + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + + scale = value.size(-1)**-0.5 if self.scale_attn_by_inverse_layer_idx: scale = scale * (1 / float(self.layer_idx + 1)) - attn_output = me_attention(query=query, key=key, value=value, attn_bias=attn_bias, p=self.attn_dropout.p, scale=scale) - - attn_output = merge_heads(attn_output, self.num_heads, self.head_dim) + + attention = ColoAttention(embed_dim=self.embed_dim, + num_heads=self.num_heads, + dropout=self.attn_dropout.p, + scale=scale) + + attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) + # attn_output = me_attention(query=query, key=key, value=value, attn_bias=attn_bias, p=self.attn_dropout.p, scale=scale) + + # attn_output = merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present, None) return outputs - + return gpt2_flash_attention_forward + def split_heads(tensor, num_heads, attn_head_size): """ Splits hidden_size dim into attn_head_size and num_heads @@ -85,9 +94,10 @@ def split_heads(tensor, num_heads, attn_head_size): tensor = tensor.view(new_shape) return tensor + def merge_heads(tensor, num_heads, attn_head_size): """ Merges attn_head_size dim and num_attn_heads dim into hidden_size """ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) \ No newline at end of file + return tensor.view(new_shape) From 19bd3b5e4d889f6b18fbf613c2c0893d0afe109e Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 13 Jul 2023 21:02:57 +0800 Subject: [PATCH 08/11] [shardformer] update gpt2 to use coloattention --- colossalai/shardformer/modeling/gpt2.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 24a19d52ee5e..6689d36211ce 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -7,11 +7,6 @@ def get_gpt2_forward(): - # try: - # from xformers.ops import memory_efficient_attention as me_attention - # from xformers.ops.fmha.attn_bias import LowerTriangularMask - # except: - # raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention def gpt2_flash_attention_forward( From 73b2ea4abcb7376b2240029458563dd24b93b501 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 13 Jul 2023 21:03:53 +0800 Subject: [PATCH 09/11] [shardformer] update gpt2 to use coloattention --- colossalai/shardformer/modeling/gpt2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 6689d36211ce..9b176264f268 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -69,9 +69,7 @@ def gpt2_flash_attention_forward( scale=scale) attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) - # attn_output = me_attention(query=query, key=key, value=value, attn_bias=attn_bias, p=self.attn_dropout.p, scale=scale) - # attn_output = merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present, None) From 35c40e517debcb3a11fb9492d30c804ad43275c5 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 13 Jul 2023 21:04:53 +0800 Subject: [PATCH 10/11] [shardformer] update gpt2 to use coloattention --- colossalai/shardformer/modeling/gpt2.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 9b176264f268..dedf958248ab 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -86,11 +86,3 @@ def split_heads(tensor, num_heads, attn_head_size): new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) tensor = tensor.view(new_shape) return tensor - - -def merge_heads(tensor, num_heads, attn_head_size): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden_size - """ - new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) From 01a2269ee4b86cce8edec8982cb4276bb34beb6f Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Sat, 15 Jul 2023 11:24:14 +0800 Subject: [PATCH 11/11] [shardformer] update gpt2 --- colossalai/shardformer/modeling/gpt2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index dedf958248ab..fbb1691ec595 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -63,6 +63,7 @@ def gpt2_flash_attention_forward( if self.scale_attn_by_inverse_layer_idx: scale = scale * (1 / float(self.layer_idx + 1)) + # use coloattention attention = ColoAttention(embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p,