From f4979b9f3b49ccba5785534bcd40ad79f429f44c Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 22 Aug 2023 01:38:14 +0800 Subject: [PATCH 01/10] [shardformer] chatglm support sequence parallel --- colossalai/shardformer/layer/linear.py | 8 +- colossalai/shardformer/modeling/chatglm.py | 135 +++++++++++++++++--- colossalai/shardformer/policies/chatglm.py | 102 +++++++++------ tests/kit/model_zoo/transformers/chatglm.py | 4 +- 4 files changed, 191 insertions(+), 58 deletions(-) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 69ac3ad2581a..53994ba490d9 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -74,6 +74,7 @@ def __init__(self, process_group: ProcessGroup = None, gather_output: bool = False, seq_parallel: bool = False, + seq_parallel_dim: int = 1, overlap: bool = False, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -87,6 +88,7 @@ def __init__(self, self.out_features = out_features self.gather_output = gather_output self.seq_parallel = seq_parallel + self.seq_parallel_dim = seq_parallel_dim self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device @@ -190,7 +192,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: bias = self.bias if not self.skip_bias_add else None if self.seq_parallel: output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, - self.process_group, True, 1, self.overlap) + self.process_group, True, self.seq_parallel_dim, self.overlap) else: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) @@ -236,6 +238,7 @@ def __init__(self, device: torch.device = None, process_group: ProcessGroup = None, seq_parallel: bool = False, + seq_parallel_dim: int = 1, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -254,6 +257,7 @@ def __init__(self, self.skip_bias_add = skip_bias_add self.process_group = process_group self.seq_parallel = seq_parallel + self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -390,7 +394,7 @@ def forward(self, input_: Tensor) -> Tensor: else: output_parallel = F.linear(input_, self.weight) if self.seq_parallel: - output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, self.seq_parallel_dim) else: output = reduce_forward(output_parallel, self.process_group) diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py index 409e2e1f5497..16dcf87c8cfc 100644 --- a/colossalai/shardformer/modeling/chatglm.py +++ b/colossalai/shardformer/modeling/chatglm.py @@ -9,6 +9,8 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from
colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -146,6 +148,7 @@ def chatglm_model_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): logger = logging.get_logger(__name__) output_hidden_states = (output_hidden_states @@ -198,6 +201,11 @@ def chatglm_model_forward( all_self_attentions = None all_hidden_states = () if output_hidden_states else None start_idx, end_idx = stage_index[0], stage_index[1] + + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) if output_hidden_states: @@ -214,6 +222,11 @@ def chatglm_model_forward( hidden_states, kv_cache = layer_ret if use_cache: presents = presents + (kv_cache,) + + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): @@ -233,23 +246,22 @@ def chatglm_model_forward( return {'hidden_states': hidden_states} @staticmethod - def chatglm_for_conditional_generation_forward( - self: ChatGLMForConditionalGeneration, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ): + def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None): logger = logging.get_logger(__name__) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) @@ -266,6 +278,7 @@ def chatglm_for_conditional_generation_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -296,3 +309,91 @@ def chatglm_for_conditional_generation_forward( ) else: return transformer_outputs + + +def 
get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. 
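+ # split_forward_gather_backward scatters activations along the sequence + # axis on the forward pass and all-gathers their gradients on the backward + # pass, so each tensor-parallel rank runs the encoder blocks on only its + # own shard of the sequence: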
+ # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] + inputs_embeds = split_forward_gather_backward(inputs_embeds, + dim=0, + process_group=shard_config.tensor_parallel_process_group) + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + hidden_states = gather_forward_split_backward(hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py index e6b458936637..22a212c7a316 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm.py @@ -15,7 +15,11 @@ GLMBlock, ) -from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward +from ..modeling.chatglm import ( + get_chatglm_sequence_parallel_forward_fn, + get_flash_core_attention_forward, + get_jit_fused_glm_block_forward, +) from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -45,8 +49,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} + use_sequence_parallel = self.shard_config.enable_sequence_parallelism if self.shard_config.enable_tensor_parallelism: - policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={}, sub_module_replacement=[ SubModuleReplacementDescription( @@ -55,36 +59,42 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) ]) - policy[GLMBlock] = ModulePolicyDescription(attribute_replacement={ - "self_attention.num_attention_heads_per_partition": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attention.projection_size": - (self.model.config.kv_channels * self.model.config.num_attention_heads) // - self.shard_config.tensor_parallel_size, - "self_attention.qkv_hidden_size": - (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // - self.shard_config.tensor_parallel_size, - "self_attention.core_attention.num_attention_heads_per_partition": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attention.core_attention.hidden_size_per_partition": - self.model.config.kv_channels * self.model.config.num_attention_heads // - self.shard_config.tensor_parallel_size, - }, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="self_attention.core_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + policy[GLMBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // 
self.shard_config.tensor_parallel_size, + "self_attention.projection_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads) // + self.shard_config.tensor_parallel_size, + "self_attention.qkv_hidden_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // + self.shard_config.tensor_parallel_size, + "self_attention.core_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.hidden_size_per_partition": + self.model.config.kv_channels * self.model.config.num_attention_heads // + self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'seq_parallel_dim': 0 + }), + SubModuleReplacementDescription(suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'seq_parallel_dim': 0 + }), + SubModuleReplacementDescription( + suffix="self_attention.core_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) + # optimization configuration if self.shard_config.enable_fused_normalization: if not self.model.config.rmsnorm: @@ -123,17 +133,29 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_key=ChatGLMModel) # use flash attention + if self.shard_config.enable_flash_attention: - policy[CoreAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_flash_core_attention_forward(), - }) + }, + policy=policy, + target_key=CoreAttention) + + # use sequence parallel + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=ChatGLMModel) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[GLMBlock] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_glm_block_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=GLMBlock) return policy @@ -178,7 +200,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py index c6473ee2a025..d543df00bdfa 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm.py @@ -12,8 +12,8 @@ def data_gen(): - input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]]) + input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075, 632, 2075]], dtype=torch.int64) + 
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]) return dict(input_ids=input_ids, attention_mask=attention_mask) From 248ea6b38df6a6808ff86f935ae7a09b89b459e7 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 22 Aug 2023 19:31:44 +0800 Subject: [PATCH 02/10] fix --- colossalai/shardformer/layer/linear.py | 6 +++-- colossalai/shardformer/policies/bert.py | 18 +++++++++----- colossalai/shardformer/policies/blip2.py | 23 +++++++++++------- colossalai/shardformer/policies/bloom.py | 26 ++++++++++++++------- colossalai/shardformer/policies/chatglm2.py | 1 - colossalai/shardformer/policies/gpt2.py | 6 +++-- colossalai/shardformer/policies/llama.py | 6 +++-- colossalai/shardformer/policies/sam.py | 12 ++++++---- colossalai/shardformer/policies/vit.py | 12 ++++++---- 9 files changed, 71 insertions(+), 39 deletions(-) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 53994ba490d9..81c3f973fd49 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -192,7 +192,8 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: bias = self.bias if not self.skip_bias_add else None if self.seq_parallel: output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, - self.process_group, True, self.seq_parallel_dim, self.overlap) + self.process_group, True, + self.seq_parallel_dim, self.overlap) else: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) @@ -394,7 +395,8 @@ def forward(self, input_: Tensor) -> Tensor: else: output_parallel = F.linear(input_, self.weight) if self.seq_parallel: - output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, self.seq_parallel_dim) + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, + self.seq_parallel_dim) else: output = reduce_forward(output_parallel, self.process_group) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index fe091c658682..19dd95fd6b6a 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -155,20 +155,26 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_bert_flash_attention_forward(), - }) + }, + policy=policy, + target_key=BertSelfAttention) # use jit operator if self.shard_config.enable_jit_fused: - policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bert_self_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[BertOutput] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=BertSelfOutput) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bert_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=BertOutput) return policy diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 3610e2c4109b..2e5388ab0490 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -285,21 +285,26 @@ def module_policy(self): # use flash attention if
self.shard_config.enable_flash_attention: - policy[Blip2Attention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_blip2_flash_attention_forward(), - }) + }, + policy=policy, + target_key=Blip2Attention) # use jit operator if self.shard_config.enable_jit_fused: - policy[Blip2QFormerSelfOutput] = ModulePolicyDescription( - method_replacement={ - 'forward': get_jit_fused_blip2_QFormer_self_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ + 'forward': get_jit_fused_blip2_QFormer_self_output_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=Blip2QFormerSelfOutput) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_blip2_QFormer_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=Blip2QFormerOutput) return policy diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 2727272d0867..21db13f6e441 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -125,25 +125,33 @@ def module_policy(self): target_key=BloomModel) if self.shard_config.enable_flash_attention: - policy[BloomAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_bloom_flash_attention_forward(), - 'dropout_add': get_dropout_add_func() - }) + 'dropout_add': get_dropout_add_func(), + }, + policy=policy, + target_key=BloomAttention) # enable jit fused operator if self.shard_config.enable_jit_fused: - policy[BloomAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bloom_attention_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[BloomMLP] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=BloomAttention) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bloom_mlp_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[BloomGelu] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=BloomMLP) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bloom_gelu_forward(), 'bloom_gelu_forward': get_jit_fused_gelu_forward_func(), - }) + }, + policy=policy, + target_key=BloomGelu) return policy diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 446e1618f31d..b0d684a67dce 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -133,7 +133,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_key=ChatGLMModel) # use flash attention - if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement(description={ 'forward': get_flash_core_attention_forward(), diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index d34c0ae9fe64..acae2630942b 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -118,9 +118,11 @@ def module_policy(self): target_key=GPT2Block) if self.shard_config.enable_flash_attention: - 
policy[GPT2Attention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_gpt2_flash_attention_forward(), - }) + }, + policy=policy, + target_key=GPT2Attention) if self.shard_config.enable_sequence_parallelism: policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 5ee95f3be8fa..ccf7764079a9 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -105,9 +105,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_key=LlamaModel) if self.shard_config.enable_flash_attention: - policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_llama_flash_attention_forward(), - }) + }, + policy=policy, + target_key=LlamaAttention) return policy diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index b1eba0432b49..9753d5a737b9 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -199,12 +199,16 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[SamAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_sam_flash_attention_forward(), - }) - policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=SamAttention) + self.append_or_create_method_replacement(description={ 'forward': get_sam_vision_flash_attention_forward(), - }) + }, + policy=policy, + target_key=SamVisionAttention) return policy diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 617720ee7950..757bab95f273 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -90,16 +90,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: # use flash attention if self.shard_config.enable_flash_attention: - policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_vit_flash_self_attention_forward(), - }) + }, + policy=policy, + target_key=ViTSelfAttention) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[ViTOutput] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_vit_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=ViTOutput) return policy def new_model_class(self): From 3cdfd7c475181f3987cb1166e33714e56f6e1013 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 23 Aug 2023 10:20:55 +0800 Subject: [PATCH 03/10] [shardformer] jit fused fix --- colossalai/shardformer/policies/opt.py | 12 ++++-- colossalai/shardformer/policies/whisper.py | 18 ++++---- .../test_model/test_shard_whisper.py | 41 +++++++++++++++++-- 3 files changed, 55 insertions(+), 16 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index ba6036bd0658..ebc2aad37e44 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -100,16 +100,20 @@ def 
module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[OPTAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_opt_flash_attention_forward(), - }) + }, + policy=policy, + target_key=OPTAttention) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_opt_decoder_layer_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=OPTDecoderLayer) return policy diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index a33f929f1e48..03b44913730b 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -198,20 +198,20 @@ def module_policy(self): # enable flash attention if self.shard_config.enable_flash_attention: - policy[WhisperAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_whisper_flash_attention_forward(), - }) + }, + policy=policy, + target_key=WhisperAttention) - # use jit fused operator + # use jit fused operator, fix WhisperEncoderLayer enable jit fused. if self.shard_config.enable_jit_fused: - policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={ - 'forward': get_jit_fused_whisper_encoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_whisper_decoder_layer_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=WhisperDecoderLayer) return policy diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 90e007e34de8..6445b314dc97 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -82,8 +82,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) - check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) # check weights after optimizer.step() org_optimizer.step() @@ -99,7 +99,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, tp_group, atol=atol, rtol=rtol, - dim=0, + dim=1, verbose=False) check_weight(whisper, sharded_whisper, @@ -155,12 +155,39 @@ def run_whisper_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_whisper_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + 
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_whisper(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_whisper_test() +def check_whisper_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_whisper_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -168,5 +195,13 @@ def test_whisper(): spawn(check_whisper, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_whisper_3d(): + spawn(check_whisper_3d, 8) + + if __name__ == "__main__": test_whisper() + test_whisper_3d() From bbbb31624259d5067fbe2a2e1dd99422f0742022 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 23 Aug 2023 10:47:56 +0800 Subject: [PATCH 04/10] [shardformer] jit fused fix --- colossalai/shardformer/policies/whisper.py | 18 ------------------ .../test_model/test_shard_whisper.py | 19 +++++++++---------- 2 files changed, 9 insertions(+), 28 deletions(-) diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 03b44913730b..96c805eb1995 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -33,7 +33,6 @@ def preprocess(self): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - # TODO: vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size if vocab_size % world_size != 0: @@ -196,23 +195,6 @@ def module_policy(self): policy=policy, target_key=WhisperDecoder) - # enable flash attention - if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_whisper_flash_attention_forward(), - }, - policy=policy, - target_key=WhisperAttention) - - # use jit fused operator, fix WhisperEncoderLayer enable jit fused. 
- if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_whisper_decoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=WhisperDecoderLayer) - return policy def add_lm_head_policy(self, base_policy): diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 6445b314dc97..16f831f7027a 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': - atol, rtol = 1e-3, 1e-3 + atol, rtol = 3e-4, 3e-4 else: atol, rtol = 5e-3, 5e-3 @@ -77,7 +77,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check weights and gradients if test_config['precision'] == 'fp32': - atol, rtol = 1e-3, 1e-3 + atol, rtol = 3e-4, 3e-4 else: atol, rtol = 5e-3, 5e-3 @@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_optimizer.step() sharded_optimizer.step() if test_config['precision'] == 'fp32': - atol, rtol = 1e-3, 1e-3 + atol, rtol = 3e-4, 3e-4 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): @@ -195,13 +195,12 @@ def test_whisper(): spawn(check_whisper, 4) -@pytest.mark.largedist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_whisper_3d(): - spawn(check_whisper_3d, 8) - +# @pytest.mark.largedist +# @rerun_if_address_is_in_use() +# @clear_cache_before_run() +# def test_whisper_3d(): +# spawn(check_whisper_3d, 8) if __name__ == "__main__": test_whisper() - test_whisper_3d() + # test_whisper_3d() From 7793325881be0dae448f5564300b59daaad0b02a Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 23 Aug 2023 11:17:57 +0800 Subject: [PATCH 05/10] [shardformer] jit fused fix --- colossalai/shardformer/policies/whisper.py | 12 ++++++++++++ .../test_model/test_shard_whisper.py | 19 ++++++++++--------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 96c805eb1995..7f4d62215ed9 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -195,6 +195,18 @@ def module_policy(self): policy=policy, target_key=WhisperDecoder) + # enable flash attention + # if self.shard_config.enable_flash_attention: + # policy[WhisperAttention] = ModulePolicyDescription(method_replacement={ + # 'forward': get_whisper_flash_attention_forward(), + # }) + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement(description={ + 'forward': get_whisper_flash_attention_forward(), + }, + policy=policy, + target_key=WhisperAttention) + return policy def add_lm_head_policy(self, base_policy): diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 16f831f7027a..441518719e88 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager 
is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': - atol, rtol = 3e-4, 3e-4 + atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 @@ -77,7 +77,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check weights and gradients if test_config['precision'] == 'fp32': - atol, rtol = 3e-4, 3e-4 + atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 @@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_optimizer.step() sharded_optimizer.step() if test_config['precision'] == 'fp32': - atol, rtol = 3e-4, 3e-4 + atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): @@ -195,12 +195,13 @@ def test_whisper(): spawn(check_whisper, 4) -# @pytest.mark.largedist -# @rerun_if_address_is_in_use() -# @clear_cache_before_run() -# def test_whisper_3d(): -# spawn(check_whisper_3d, 8) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_whisper_3d(): + spawn(check_whisper_3d, 8) + if __name__ == "__main__": test_whisper() - # test_whisper_3d() + test_whisper_3d() From c0b4ed5b68d2ca9faf74faea567aef80234f9f1a Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 23 Aug 2023 11:20:58 +0800 Subject: [PATCH 06/10] [shardformer] jit fused fix --- colossalai/shardformer/policies/whisper.py | 4 ---- tests/test_shardformer/test_model/test_shard_whisper.py | 1 + 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 7f4d62215ed9..ed087fe8dcd8 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -196,10 +196,6 @@ def module_policy(self): target_key=WhisperDecoder) # enable flash attention - # if self.shard_config.enable_flash_attention: - # policy[WhisperAttention] = ModulePolicyDescription(method_replacement={ - # 'forward': get_whisper_flash_attention_forward(), - # }) if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement(description={ 'forward': get_whisper_flash_attention_forward(), diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 441518719e88..c10638d15fcd 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # TODO(jianghai) fix fp16 +#TODO fix jit fused operator with WhisperForConditionalGeneration @parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, From acce981b8a22830d8b72d64324804bb6f68aab42 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 23 Aug 2023 13:40:01 +0800 Subject: [PATCH 07/10] [shardformer] jit fused fix --- .../test_model/test_shard_whisper.py | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index c10638d15fcd..be8ca1e6dcd0 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -156,39 +156,12 @@ def run_whisper_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 
'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, -]) -def run_whisper_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') - - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) - - clear_layout_converter() - torch.cuda.empty_cache() - - def check_whisper(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_whisper_test() -def check_whisper_3d(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_whisper_3d_test() - - @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -196,13 +169,5 @@ def test_whisper(): spawn(check_whisper, 4) -@pytest.mark.largedist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_whisper_3d(): - spawn(check_whisper_3d, 8) - - if __name__ == "__main__": test_whisper() From a9f299c39d2bd6164bfaead2fd8544710d8056df Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 23 Aug 2023 16:55:02 +0800 Subject: [PATCH 08/10] [shardformer] jit fused fix --- colossalai/shardformer/modeling/bert.py | 3 +++ colossalai/shardformer/policies/llama.py | 5 +++++ colossalai/shardformer/policies/t5.py | 5 +++++ colossalai/shardformer/policies/vit.py | 5 +++++ colossalai/shardformer/policies/whisper.py | 6 ++++++ 5 files changed, 24 insertions(+) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index d88661953a29..825a9fbde169 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1241,6 +1241,9 @@ def forward( embedding_output = split_forward_gather_backward(embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) encoder_outputs = self.encoder( embedding_output, diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index ccf7764079a9..c417e5d017bd 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -35,6 +36,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[LlamaDecoderLayer] = ModulePolicyDescription( attribute_replacement={ diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 651883d35b87..192a1b8472fc 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Optional, Tuple @@ -59,6 +60,10 @@ def module_policy(self): policy = {} + if self.shard_config.enable_sequence_parallelism: +
self.shard_config.enable_sequence_parallelism = False + warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 757bab95f273..b4fb8692e684 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,3 +1,4 @@ +import warnings from typing import Callable, Dict, List, Union import torch.nn as nn @@ -32,6 +33,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("ViT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={}, param_replacement=[], diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index ed087fe8dcd8..9ddf0dc83c4c 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Tuple @@ -51,6 +52,11 @@ def module_policy(self): policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn( + "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ "self_attn.embed_dim": From 8cb81cc19bbfb6e610072cd8e98edc893a9838c5 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 23 Aug 2023 18:01:20 +0800 Subject: [PATCH 09/10] [shardformer] jit fused fix --- colossalai/shardformer/modeling/bert.py | 3 +++ tests/test_shardformer/test_model/test_shard_whisper.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 825a9fbde169..30855a622adb 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -187,6 +187,9 @@ def bert_model_forward( hidden_states = split_forward_gather_backward(hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index be8ca1e6dcd0..8c1cb3c42096 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -114,7 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # TODO(jianghai) fix fp16 -#TODO fix jit fused operator with WhisperForConditionalGeneration +#TODO fix jit fused operator for WhisperForConditionalGeneration
@parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, From 78f7af2b443d94ee54453184ebbd4c7e41621706 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 24 Aug 2023 10:53:57 +0800 Subject: [PATCH 10/10] activate checks --- colossalai/shardformer/policies/whisper.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 9ddf0dc83c4c..bffb624d0d1a 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -56,6 +56,9 @@ def module_policy(self): self.shard_config.enable_sequence_parallelism = False warnings.warn( "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_jit_fused: + self.shard_config.enable_jit_fused = False + warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused flag.") if self.shard_config.enable_tensor_parallelism: policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
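
The heart of this series is a pair of conjugate collectives: split activations along the sequence axis on the forward pass (all-gathering their gradients on backward), run the transformer blocks on the local shard, then gather the activations back (splitting gradients on backward). Below is a minimal, self-contained sketch of that pattern for illustration only; it is not ColossalAI's actual implementation, which additionally fuses an all-gather/reduce-scatter into `Linear1D_Col`/`Linear1D_Row` via the `seq_parallel_dim` plumbing added in patch 01 and can overlap communication with compute. The sketch assumes an already-initialized `torch.distributed` process group and a sequence length divisible by the group size, and `run_encoder_sequence_parallel` is a hypothetical helper, not an API from these patches.

```python
import torch
import torch.distributed as dist


class _SplitForwardGatherBackward(torch.autograd.Function):
    # Forward: keep only this rank's chunk along `dim`.
    # Backward: all-gather the per-rank chunk gradients to rebuild the full gradient.

    @staticmethod
    def forward(ctx, x, dim, group):
        ctx.dim, ctx.group = dim, group
        world = dist.get_world_size(group)
        rank = dist.get_rank(group)
        return x.chunk(world, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world = dist.get_world_size(ctx.group)
        parts = [torch.empty_like(grad_output) for _ in range(world)]
        dist.all_gather(parts, grad_output.contiguous(), group=ctx.group)
        return torch.cat(parts, dim=ctx.dim), None, None


class _GatherForwardSplitBackward(torch.autograd.Function):
    # Forward: all-gather the chunks along `dim` to rebuild the full tensor.
    # Backward: hand each rank only the gradient slice of its own chunk.

    @staticmethod
    def forward(ctx, x, dim, group):
        ctx.dim, ctx.group = dim, group
        world = dist.get_world_size(group)
        parts = [torch.empty_like(x) for _ in range(world)]
        dist.all_gather(parts, x.contiguous(), group=group)
        return torch.cat(parts, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world = dist.get_world_size(ctx.group)
        rank = dist.get_rank(ctx.group)
        return grad_output.chunk(world, dim=ctx.dim)[rank].contiguous(), None, None


def run_encoder_sequence_parallel(encoder, hidden_states, group, seq_dim=0):
    """Hypothetical helper: wrap an encoder in the split/gather pair."""
    hidden_states = _SplitForwardGatherBackward.apply(hidden_states, seq_dim, group)
    hidden_states = encoder(hidden_states)  # each rank sees seq_len / tp_size tokens
    return _GatherForwardSplitBackward.apply(hidden_states, seq_dim, group)
```

The choice of `seq_dim` is why these patches thread `seq_parallel_dim` through the linear layers: ChatGLM keeps activations as [seq_len, batch_size, hidden_size], so its policy passes `seq_parallel_dim=0` and the modeling code splits on dim 0, whereas a [batch_size, seq_len, hidden_size] layout such as GPT-2's uses the default of 1.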