diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py new file mode 100644 index 000000000000..d91e7c73b52b --- /dev/null +++ b/colossalai/shardformer/modeling/chatglm.py @@ -0,0 +1,111 @@ +import torch +import torch.nn.functional as F + + +def get_flash_core_attention_forward(): + + from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention + + from .chatglm2_6b.modeling_chatglm import CoreAttention + + def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split(".")[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + query_layer = query_layer.permute(1, 0, 2, 3).contiguous() + key_layer = key_layer.permute(1, 0, 2, 3).contiguous() + value_layer = value_layer.permute(1, 0, 2, 3).contiguous() + + scale = 1.0 / self.norm_factor + if self.coeff is not None: + scale = scale * self.coeff + + flash_attention_mask = None + attn_mask_type = None + if attention_mask is None: + attn_mask_type = AttnMaskType.causal + else: + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + attn_mask_type = AttnMaskType.paddedcausal + + attention = ColoAttention(embed_dim=self.hidden_size_per_partition, + num_heads=self.num_attention_heads_per_partition, + dropout=self.attention_dropout.p, + scale=scale) + context_layer = attention(query_layer, + key_layer, + value_layer, + attn_mask=flash_attention_mask, + attn_mask_type=attn_mask_type) + + context_layer = context_layer.permute(1, 0, -1).contiguous() + + return context_layer + + return forward + + +def get_jit_fused_glm_block_forward(): + + from .chatglm2_6b.modeling_chatglm import GLMBlock + + def forward( + self: GLMBlock, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [s, b, h] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = self.dropout_add(attention_output, residual, self.hidden_dropout, self.training) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = self.dropout_add(mlp_output, residual, self.hidden_dropout, self.training) + + return output, kv_cache + + return forward diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py similarity index 100% rename from tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py rename to colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py similarity index 100% rename from tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py rename to colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index 46aa3b52af8f..6054e923f35c 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -4,6 +4,8 @@ import colossalai.shardformer.layer as col_nn +from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward +from ..modeling.jit import get_jit_fused_dropout_add_func from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] @@ -26,7 +28,7 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock + from ..modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock policy = {} @@ -107,6 +109,19 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=ChatGLMModel) + # use flash attention + if self.shard_config.enable_flash_attention: + policy[CoreAttention] = ModulePolicyDescription(method_replacement={ + 'forward': get_flash_core_attention_forward(), + }) + + # use jit fused operator + if self.shard_config.enable_jit_fused: + policy[GLMBlock] = ModulePolicyDescription(method_replacement={ + 'forward': get_jit_fused_glm_block_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }) + return policy def postprocess(self): diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index 04e73a832abe..2d94deac29de 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -1,9 +1,10 @@ import torch import transformers +from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + from ..registry import ModelAttribute, model_zoo -from .chatglm2_6b.configuration_chatglm import ChatGLMConfig -from .chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel # ================================ # Register single-sentence ChatGLM diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index a0fa4bd82e74..63a6dd54a9e2 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -72,7 +72,9 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) -def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism): +@parameterize('enable_flash_attention', [True, False]) +@parameterize('enable_jit_fused', [True, False]) +def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): # create new model @@ -80,7 +82,9 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism): # shard model shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism) + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention, + enable_jit_fused=enable_jit_fused) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) if name == "transformers_chatglm":