diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7fad4948dfd0..72e002497fa5 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -6,7 +6,15 @@ from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ - "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', - 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm' + "Embedding1D", + "VocabParallelEmbedding1D", + "Linear1D_Col", + "Linear1D_Row", + 'GPT2FusedLinearConv1D_Col', + 'GPT2FusedLinearConv1D_Row', + 'DropoutForParallelInput', + 'DropoutForReplicatedInput', + "cross_entropy_1d", + 'FusedLayerNorm', + 'FusedRMSNorm', ] diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py new file mode 100644 index 000000000000..8b084fcb72ae --- /dev/null +++ b/colossalai/shardformer/modeling/opt.py @@ -0,0 +1,88 @@ +from typing import Optional, Tuple + +import torch + +__all__ = ['opt_flash_attention_forward'] + + +def opt_flash_attention_forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) + # get query proj + # query_states = self._shape(self.q_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(*attention_input_shape) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) + value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*attention_input_shape) + value_states = self.v_proj(key_value_states).view(*attention_input_shape) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*attention_input_shape) + value_states = self.v_proj(hidden_states).view(*attention_input_shape) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + if layer_head_mask != None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + if attention_mask != None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, tgt_len).contiguous() + + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + + attn_output = me_attention(query_states, + key_states, + value_states, + attn_bias=attention_mask, + p=self.dropout, + scale=self.scaling) + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index b87db53f45f1..fa7bf1befd44 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,6 +1,7 @@ from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .._utils import getattr_, setattr_ +from ..modeling.opt import opt_flash_attention_forward from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -89,6 +90,12 @@ def module_policy(self): policy=policy, target_key=OPTDecoderLayer) + # use flash attention + if self.shard_config.enable_flash_attention: + policy[OPTAttention] = ModulePolicyDescription(method_replacement={ + 'forward': opt_flash_attention_forward, + }) + return policy def postprocess(self): @@ -107,12 +114,12 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() - if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), policy=policy, target_key=OPTForCausalLM) + return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 83c08d275df3..4076becdbbc9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -21,6 +21,7 @@ class ShardConfig: enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False + enable_flash_attention: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int @@ -39,7 +40,6 @@ def __post_init__(self): else: # get the parallel size self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) - # turn on all optimization if all_optimization is set to True if self.enable_all_optimization: self._turn_on_all_optimization() @@ -50,3 +50,4 @@ def _turn_on_all_optimization(self): """ # you can add all the optimization flag here self.enable_fused_normalization = True + self.enable_flash_attention = True diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d83d9ecd39e0..44a8f841a6d4 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,13 +3,17 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True): +def build_model(model_fn, + enable_fused_normalization=True, + enable_tensor_parallelism=True, + enable_flash_attention=False): # create new model org_model = model_fn().cuda() # shard model shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism) + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 25bccb13b1a8..9ca0669340a5 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -9,6 +9,7 @@ from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, + check_state_dict_equal, clear_cache_before_run, parameterize, rerun_if_address_is_in_use, @@ -71,10 +72,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, + enable_flash_attention) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() @@ -82,7 +85,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): def check_OPTModel(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_t5_test() + run_opt_test() @pytest.mark.dist