From cb9d34b7da30f2326c96eafd3df02f14545d0d59 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 16:28:28 +0800 Subject: [PATCH 1/7] [shardformer] opt support flash attention --- colossalai/shardformer/layer/__init__.py | 3 +- .../layer/flash_attention/__init__.py | 3 + .../flash_attention/flash_attention_opt.py | 87 +++++++++++ colossalai/shardformer/policies/opt.py | 146 +++++++++++------- colossalai/shardformer/shard/shard_config.py | 12 +- .../test_flash_attention_for_opt.py | 38 +++++ tests/test_shardformer/test_model/_utils.py | 5 +- .../test_model/test_shard_opt.py | 39 ++--- 8 files changed, 241 insertions(+), 92 deletions(-) create mode 100644 colossalai/shardformer/layer/flash_attention/__init__.py create mode 100644 colossalai/shardformer/layer/flash_attention/flash_attention_opt.py create mode 100644 tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7fad4948dfd0..fef65fc5fa52 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,5 +1,6 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D +from .flash_attention import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm @@ -8,5 +9,5 @@ __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm' + 'FusedLayerNorm', 'FusedRMSNorm', 'opt_flash_attention_forward' ] diff --git a/colossalai/shardformer/layer/flash_attention/__init__.py b/colossalai/shardformer/layer/flash_attention/__init__.py new file mode 100644 index 000000000000..e3ee83b005c4 --- /dev/null +++ b/colossalai/shardformer/layer/flash_attention/__init__.py @@ -0,0 +1,3 @@ +from .flash_attention_opt import opt_flash_attention_forward + +__all__ = ['opt_flash_attention_forward'] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py new file mode 100644 index 000000000000..3f0445078357 --- /dev/null +++ b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py @@ -0,0 +1,87 @@ +from typing import Optional, Tuple + +import torch + +__all__ = ['opt_flash_attention_forward'] + + +def opt_flash_attention_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() + + try: + from xformers.ops import memory_efficient_attention as me_attention + except ImportError: + print("Error: xformers module is not installed. Please install it to use flash attention.") + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=attention_mask, + p=self.dropout, + scale=self.scaling) + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index b87db53f45f1..725def7f66ab 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,4 +1,10 @@ -from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedLayerNorm, + Linear1D_Col, + Linear1D_Row, + VocabParallelEmbedding1D, + opt_flash_attention_forward, +) from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -29,67 +35,90 @@ def preprocess(self): def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer - policy = {} - - if self.shard_config.enable_tensor_parallelism: - policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]) - policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]) - - policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]) + base_policy = { + OPTDecoder: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]), + OPTDecoderLayer: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]), + OPTAttention: + ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]), + } # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), - policy=policy, - target_key=OPTDecoder) - self.append_or_create_submodule_replacement(description=[ + base_policy[OPTDecoder].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="final_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True)) + base_policy[OPTDecoderLayer].sub_module_replacement.extend([ SubModuleReplacementDescription(suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True) - ], - policy=policy, - target_key=OPTDecoderLayer) + ]) - return policy + # use flash attention + if self.shard_config.enable_flash_attention: + del base_policy[OPTAttention] + new_item = { + OPTAttention: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + method_replacement={ + 'forward': opt_flash_attention_forward, + }, + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="q_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="k_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="v_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + ]), + } + base_policy.update(new_item) + + return base_policy def postprocess(self): return self.model @@ -107,12 +136,15 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() + new_item = { + OPTForCausalLM: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } - if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=OPTForCausalLM) + policy.update(new_item) return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 83c08d275df3..0e826d38ddd0 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,14 +13,14 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. - enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. + enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. enable_all_optimization (bool): Whether to turn on all optimization, default is False. """ tensor_parallel_process_group: ProcessGroup = None - enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False + enable_flash_attention: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int @@ -34,11 +34,8 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): - if not self.enable_tensor_parallelism: - self._tensor_parallel_size = 1 - else: - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) # turn on all optimization if all_optimization is set to True if self.enable_all_optimization: @@ -50,3 +47,4 @@ def _turn_on_all_optimization(self): """ # you can add all the optimization flag here self.enable_fused_normalization = True + self.enable_flash_attention = True diff --git a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py new file mode 100644 index 000000000000..7f196e86a253 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py @@ -0,0 +1,38 @@ +from copy import deepcopy + +import torch +import torch.nn as nn +from torch.testing import assert_close +from transformers.models.opt.modeling_opt import OPTAttention + +from colossalai.shardformer.layer import opt_flash_attention_forward + + +def test_flash_attention_for_opt(): + BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 + + # generate input + hidden_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), + dtype=torch.float32, + device="cuda", + requires_grad=True) + key_value_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), + dtype=torch.float32, + device="cuda", + requires_grad=True) + attention_mask = torch.ones((BATCH, 1, N_CTX, N_CTX), dtype=torch.float32, device="cuda") + + opt_attention = OPTAttention(embed_dim=D_HEAD * N_HEADS, num_heads=N_HEADS, dropout=0, is_decoder=True, + bias=True).to("cuda") + + opt_flash_attention = deepcopy(opt_attention) + setattr(opt_flash_attention, 'forward', + opt_flash_attention_forward.__get__(opt_flash_attention, opt_flash_attention.__class__)) + opt_attention_output = opt_attention(hidden_states, key_value_states, attention_mask=attention_mask) + flash_attention_output = opt_flash_attention(hidden_states, key_value_states, attention_mask=attention_mask) + + assert_close(flash_attention_output[0], opt_attention_output[0], atol=1e-5, rtol=1e-5) + + +if __name__ == '__main__': + test_flash_attention_for_opt() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d83d9ecd39e0..7eff506e0c28 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,13 +3,12 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True): +def build_model(model_fn, enable_flash_attention=False): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism) + shard_config = ShardConfig(enable_fused_normalization=True, enable_flash_attention=enable_flash_attention) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 25bccb13b1a8..af25ba6a8c6c 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,9 +6,9 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, + check_state_dict_equal, clear_cache_before_run, parameterize, rerun_if_address_is_in_use, @@ -43,53 +43,44 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad - shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # check embedding grad org_grad = opt_model.decoder.embed_tokens.weight.grad shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad - shard_weight = shard_opt_model.decoder.embed_tokens.weight - - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +def check_OPTModel(enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() -def check_OPTModel(rank, world_size, port): +def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_t5_test() + check_OPTModel() + torch.cuda.empty_cache() @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_OPTModel(): - spawn(check_OPTModel, 4) + spawn(run_dist, 4) if __name__ == '__main__': From ac179b5c7ea8d11b77711c1eae1b916a72f656fc Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 17:13:19 +0800 Subject: [PATCH 2/7] [shardformer] opt support flash attention --- colossalai/shardformer/policies/opt.py | 161 +++++++++--------- colossalai/shardformer/shard/shard_config.py | 14 +- tests/test_shardformer/test_model/_utils.py | 9 +- .../test_model/test_shard_opt.py | 38 +++-- 4 files changed, 120 insertions(+), 102 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 725def7f66ab..5e1c1cf36ed4 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -35,90 +35,94 @@ def preprocess(self): def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer - base_policy = { - OPTDecoder: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]), - OPTDecoderLayer: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]), - OPTAttention: - ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]), - } + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]) + policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]) + + policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[OPTDecoder].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="final_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True)) - base_policy[OPTDecoderLayer].sub_module_replacement.extend([ + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription(suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True) - ]) + ], + policy=policy, + target_key=OPTDecoderLayer) # use flash attention if self.shard_config.enable_flash_attention: - del base_policy[OPTAttention] - new_item = { - OPTAttention: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - method_replacement={ - 'forward': opt_flash_attention_forward, - }, - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="q_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="k_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="v_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - ]), - } - base_policy.update(new_item) - - return base_policy + policy[OPTAttention] = ModulePolicyDescription( + method_replacement={ + 'forward': opt_flash_attention_forward, + }, + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="q_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="k_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="v_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription(suffix="out_proj", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + ]) + + return policy def postprocess(self): return self.model @@ -136,15 +140,12 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() - new_item = { - OPTForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=OPTForCausalLM) - policy.update(new_item) return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0e826d38ddd0..792e9a3f7ad4 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,11 +13,12 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. + enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. - enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. enable_all_optimization (bool): Whether to turn on all optimization, default is False. """ tensor_parallel_process_group: ProcessGroup = None + enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False enable_flash_attention: bool = False @@ -34,12 +35,11 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) - - # turn on all optimization if all_optimization is set to True - if self.enable_all_optimization: - self._turn_on_all_optimization() + if not self.enable_tensor_parallelism: + self._tensor_parallel_size = 1 + else: + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) def _turn_on_all_optimization(self): """ diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 7eff506e0c28..44a8f841a6d4 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,12 +3,17 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_flash_attention=False): +def build_model(model_fn, + enable_fused_normalization=True, + enable_tensor_parallelism=True, + enable_flash_attention=False): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(enable_fused_normalization=True, enable_flash_attention=enable_flash_attention) + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index af25ba6a8c6c..9ca0669340a5 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,6 +6,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, check_state_dict_equal, @@ -43,44 +44,55 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad + shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # check embedding grad org_grad = opt_model.decoder.embed_tokens.weight.grad shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + shard_weight = shard_opt_model.decoder.embed_tokens.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) @parameterize('enable_flash_attention', [True, False]) -def check_OPTModel(enable_flash_attention): +def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_flash_attention) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() -def run_dist(rank, world_size, port): +def check_OPTModel(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_OPTModel() - torch.cuda.empty_cache() + run_opt_test() @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_OPTModel(): - spawn(run_dist, 4) + spawn(check_OPTModel, 4) if __name__ == '__main__': From 00c4a826e2fe0fcc6e652b87fdb80e60e83fdbfe Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 17:21:08 +0800 Subject: [PATCH 3/7] [shardformer] opt support flash attention --- colossalai/shardformer/policies/opt.py | 6 ------ colossalai/shardformer/shard/shard_config.py | 3 +++ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 5e1c1cf36ed4..7a838a632776 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -79,12 +79,6 @@ def module_policy(self): ]) # optimization configuration - if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), - policy=policy, - target_key=OPTDecoder) - # optimization configuration if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 792e9a3f7ad4..4076becdbbc9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -40,6 +40,9 @@ def __post_init__(self): else: # get the parallel size self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + # turn on all optimization if all_optimization is set to True + if self.enable_all_optimization: + self._turn_on_all_optimization() def _turn_on_all_optimization(self): """ From 27526d943654b10f75304d3f08e1f02dc8134dfc Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 4 Jul 2023 18:22:23 +0800 Subject: [PATCH 4/7] [shardformer] opt support flash attention --- colossalai/shardformer/layer/__init__.py | 2 +- .../layer/flash_attention/__init__.py | 3 --- .../flash_attention/flash_attention_opt.py | 5 +++-- colossalai/shardformer/policies/opt.py | 21 +++---------------- 4 files changed, 7 insertions(+), 24 deletions(-) delete mode 100644 colossalai/shardformer/layer/flash_attention/__init__.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index fef65fc5fa52..a994e8f37e91 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,6 +1,6 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D -from .flash_attention import opt_flash_attention_forward +from .flash_attention.flash_attention_opt import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm diff --git a/colossalai/shardformer/layer/flash_attention/__init__.py b/colossalai/shardformer/layer/flash_attention/__init__.py deleted file mode 100644 index e3ee83b005c4..000000000000 --- a/colossalai/shardformer/layer/flash_attention/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .flash_attention_opt import opt_flash_attention_forward - -__all__ = ['opt_flash_attention_forward'] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py index 3f0445078357..8b084fcb72ae 100644 --- a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py +++ b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py @@ -68,8 +68,9 @@ def opt_flash_attention_forward( try: from xformers.ops import memory_efficient_attention as me_attention - except ImportError: - print("Error: xformers module is not installed. Please install it to use flash attention.") + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + attn_output = me_attention(query_states, key_states, value_states, diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 7a838a632776..9ca0373c69da 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -97,24 +97,9 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[OPTAttention] = ModulePolicyDescription( - method_replacement={ - 'forward': opt_flash_attention_forward, - }, - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="q_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="k_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="v_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription(suffix="out_proj", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - ]) + policy[OPTAttention] = ModulePolicyDescription(method_replacement={ + 'forward': opt_flash_attention_forward, + }) return policy From 70535bd398f25660f60a6912856151172fa128d5 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 5 Jul 2023 11:16:46 +0800 Subject: [PATCH 5/7] [shardformer] move to modeling --- colossalai/shardformer/layer/__init__.py | 15 +++- .../flash_attention/flash_attention_opt.py | 88 ------------------- colossalai/shardformer/policies/opt.py | 9 +- .../test_flash_attention_for_opt.py | 38 -------- 4 files changed, 13 insertions(+), 137 deletions(-) delete mode 100644 colossalai/shardformer/layer/flash_attention/flash_attention_opt.py delete mode 100644 tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index a994e8f37e91..72e002497fa5 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,13 +1,20 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D -from .flash_attention.flash_attention_opt import opt_flash_attention_forward from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ - "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', - 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm', 'opt_flash_attention_forward' + "Embedding1D", + "VocabParallelEmbedding1D", + "Linear1D_Col", + "Linear1D_Row", + 'GPT2FusedLinearConv1D_Col', + 'GPT2FusedLinearConv1D_Row', + 'DropoutForParallelInput', + 'DropoutForReplicatedInput', + "cross_entropy_1d", + 'FusedLayerNorm', + 'FusedRMSNorm', ] diff --git a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py b/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py deleted file mode 100644 index 8b084fcb72ae..000000000000 --- a/colossalai/shardformer/layer/flash_attention/flash_attention_opt.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Optional, Tuple - -import torch - -__all__ = ['opt_flash_attention_forward'] - - -def opt_flash_attention_forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - - attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) - # get query proj - # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) - query_states = self.q_proj(hidden_states).view(*attention_input_shape) - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) - value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*attention_input_shape) - value_states = self.v_proj(key_value_states).view(*attention_input_shape) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) - key_states = torch.cat([past_key_value[0], key_states], dim=1) - value_states = torch.cat([past_key_value[1], value_states], dim=1) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - src_len = key_states.size(1) - if layer_head_mask != None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}") - if attention_mask != None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") - attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() - - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - - attn_output = me_attention(query_states, - key_states, - value_states, - attn_bias=attention_mask, - p=self.dropout, - scale=self.scaling) - - attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 9ca0373c69da..fa7bf1befd44 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,12 +1,7 @@ -from colossalai.shardformer.layer import ( - FusedLayerNorm, - Linear1D_Col, - Linear1D_Row, - VocabParallelEmbedding1D, - opt_flash_attention_forward, -) +from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .._utils import getattr_, setattr_ +from ..modeling.opt import opt_flash_attention_forward from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ diff --git a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py b/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py deleted file mode 100644 index 7f196e86a253..000000000000 --- a/tests/test_shardformer/test_layer/test_flash_attention/test_flash_attention_for_opt.py +++ /dev/null @@ -1,38 +0,0 @@ -from copy import deepcopy - -import torch -import torch.nn as nn -from torch.testing import assert_close -from transformers.models.opt.modeling_opt import OPTAttention - -from colossalai.shardformer.layer import opt_flash_attention_forward - - -def test_flash_attention_for_opt(): - BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 - - # generate input - hidden_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), - dtype=torch.float32, - device="cuda", - requires_grad=True) - key_value_states = torch.randn((BATCH, N_CTX, N_HEADS * D_HEAD), - dtype=torch.float32, - device="cuda", - requires_grad=True) - attention_mask = torch.ones((BATCH, 1, N_CTX, N_CTX), dtype=torch.float32, device="cuda") - - opt_attention = OPTAttention(embed_dim=D_HEAD * N_HEADS, num_heads=N_HEADS, dropout=0, is_decoder=True, - bias=True).to("cuda") - - opt_flash_attention = deepcopy(opt_attention) - setattr(opt_flash_attention, 'forward', - opt_flash_attention_forward.__get__(opt_flash_attention, opt_flash_attention.__class__)) - opt_attention_output = opt_attention(hidden_states, key_value_states, attention_mask=attention_mask) - flash_attention_output = opt_flash_attention(hidden_states, key_value_states, attention_mask=attention_mask) - - assert_close(flash_attention_output[0], opt_attention_output[0], atol=1e-5, rtol=1e-5) - - -if __name__ == '__main__': - test_flash_attention_for_opt() From 794dd86a6e63e97df35cad6c746c08caace758e6 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 5 Jul 2023 11:19:51 +0800 Subject: [PATCH 6/7] [shardformer] move to modeling --- colossalai/shardformer/modeling/opt.py | 88 ++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 colossalai/shardformer/modeling/opt.py diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py new file mode 100644 index 000000000000..8b084fcb72ae --- /dev/null +++ b/colossalai/shardformer/modeling/opt.py @@ -0,0 +1,88 @@ +from typing import Optional, Tuple + +import torch + +__all__ = ['opt_flash_attention_forward'] + + +def opt_flash_attention_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=attention_mask, + p=self.dropout, + scale=self.scaling) + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value From ecced2af0f1c814b79b80d738ec23681d72e487e Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Mon, 10 Jul 2023 18:23:11 +0800 Subject: [PATCH 7/7] [shardformer] bert support flash attention --- colossalai/shardformer/modeling/bert.py | 113 ++++++++++++++++++ colossalai/shardformer/policies/bert.py | 10 +- tests/kit/model_zoo/transformers/bert.py | 14 +-- .../test_model/test_shard_bert.py | 5 +- 4 files changed, 132 insertions(+), 10 deletions(-) create mode 100644 colossalai/shardformer/modeling/bert.py diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py new file mode 100644 index 000000000000..faa579529ce4 --- /dev/null +++ b/colossalai/shardformer/modeling/bert.py @@ -0,0 +1,113 @@ +from typing import Optional, Tuple +import math + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.nn import functional as F + +__all__ = ['get_bert_forward'] + +def get_bert_forward(): + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + + def bert_flash_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + final_attention_mask = None + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + final_attention_mask = relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + final_attention_mask = relative_position_scores_query + relative_position_scores_key + + scale = 1 / math.sqrt(self.attention_head_size) + if attention_mask is not None: + if final_attention_mask != None: + final_attention_mask = final_attention_mask * scale + attention_mask + else: + final_attention_mask = attention_mask + batch_size, src_len = query_layer.size()[0], query_layer.size()[2] + tgt_len = key_layer.size()[2] + final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len) + + query_layer = query_layer.permute(0, 2, 1, 3).contiguous() + key_layer = key_layer.permute(0, 2, 1, 3).contiguous() + value_layer = value_layer.permute(0, 2, 1, 3).contiguous() + + context_layer = me_attention(query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale) + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, None) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + return bert_flash_attention_forward \ No newline at end of file diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 9c2736cc64d3..a213538f8809 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -4,6 +4,7 @@ from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from ..modeling.bert import get_bert_forward __all__ = [ 'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy', @@ -31,7 +32,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer + from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertSelfAttention policy = {} @@ -120,6 +121,13 @@ def module_policy(self): )], policy=policy, target_key=BertEmbeddings) + + # use flash attention + if self.shard_config.enable_flash_attention: + policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_bert_forward(), + }) + return policy def add_lm_head_policy(self, base_policy): diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index d2d3de7b7bee..248a4afe27e8 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -20,7 +20,7 @@ def data_gen(): # token_type_ids = tokenized_input['token_type_ids'] input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64) token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.int64) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) @@ -69,19 +69,19 @@ def data_gen_for_mcq(): # data['labels'] = torch.tensor([0], dtype=torch.int64) input_ids = torch.tensor([[[ 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, - 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102 + 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102 ], [ 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096, - 2218, 1999, 1996, 2192, 1012, 102, 0 + 2218, 1999, 1996, 2192, 1012, 102, 0, 0 ]]]) token_type_ids = torch.tensor( - [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]]) + [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]]) attention_mask = torch.tensor( - [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]]) + [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]]) labels = torch.tensor([0], dtype=torch.int64) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 1afedb7079ea..865ba2fa7126 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -69,10 +69,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_bert_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache()