10 changes: 8 additions & 2 deletions colossalai/shardformer/layer/linear.py
@@ -74,6 +74,7 @@ def __init__(self,
process_group: ProcessGroup = None,
gather_output: bool = False,
seq_parallel: bool = False,
seq_parallel_dim: int = 1,
overlap: bool = False,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
@@ -87,6 +88,7 @@ def __init__(self,
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel = seq_parallel
self.seq_parallel_dim = seq_parallel_dim
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
@@ -190,7 +192,8 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel:
output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
self.process_group, True, 1, self.overlap)
self.process_group, True,
self.seq_parallel_dim, self.overlap)
else:
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)

@@ -236,6 +239,7 @@ def __init__(self,
device: torch.device = None,
process_group: ProcessGroup = None,
seq_parallel: bool = False,
seq_parallel_dim: int = 1,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
@@ -254,6 +258,7 @@ def __init__(self,
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.seq_parallel = seq_parallel
self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)

if skip_bias_add and not bias:
@@ -390,7 +395,8 @@ def forward(self, input_: Tensor) -> Tensor:
else:
output_parallel = F.linear(input_, self.weight)
if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group,
self.seq_parallel_dim)
else:
output = reduce_forward(output_parallel, self.process_group)

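Note: both linear layers now take the scatter/gather axis from `seq_parallel_dim` instead of hardcoding dim 1, which is what lets models with a `[seq_len, batch, hidden]` layout (such as ChatGLM2) shard along dim 0. A minimal usage sketch, not from the PR — it assumes `torch.distributed` is already initialized (e.g. under `torchrun`) and that the constructor takes `in_features, out_features` first, mirroring the attributes assigned above:

    # Minimal sketch: construct the sharded layers with the new knob.
    import torch.distributed as dist
    from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row

    pg = dist.group.WORLD  # tensor-parallel process group

    col = Linear1D_Col(1024, 4096,
                       process_group=pg,
                       seq_parallel=True,
                       seq_parallel_dim=0)  # gather/scatter along the sequence axis
    row = Linear1D_Row(4096, 1024,
                       process_group=pg,
                       seq_parallel=True,
                       seq_parallel_dim=0)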
135 changes: 118 additions & 17 deletions colossalai/shardformer/modeling/chatglm2.py
@@ -9,6 +9,8 @@
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
@@ -146,6 +148,7 @@ def chatglm_model_forward(
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
logger = logging.get_logger(__name__)
output_hidden_states = (output_hidden_states
@@ -198,6 +201,11 @@ def chatglm_model_forward(
all_self_attentions = None
all_hidden_states = () if output_hidden_states else None
start_idx, end_idx = stage_index[0], stage_index[1]

if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
if output_hidden_states:
@@ -214,6 +222,11 @@
hidden_states, kv_cache = layer_ret
if use_cache:
presents = presents + (kv_cache,)

if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
@@ -233,23 +246,22 @@
return {'hidden_states': hidden_states}

@staticmethod
def chatglm_for_conditional_generation_forward(
self: ChatGLMForConditionalGeneration,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None):
logger = logging.get_logger(__name__)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
@@ -266,6 +278,7 @@ def chatglm_for_conditional_generation_forward(
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
@@ -296,3 +309,91 @@ def chatglm_for_conditional_generation_forward(
)
else:
return transformer_outputs


def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):

def forward(
self,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)

batch_size, seq_length = input_ids.shape

if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)

if self.pre_seq_len is not None:
if past_key_values is None:
past_key_values = self.get_prompt(
batch_size=batch_size,
device=input_ids.device,
dtype=inputs_embeds.dtype,
)
if attention_mask is not None:
attention_mask = torch.cat(
[
attention_mask.new_ones((batch_size, self.pre_seq_len)),
attention_mask,
],
dim=-1,
)

if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

# Run encoder.
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
inputs_embeds = split_forward_gather_backward(inputs_embeds,
dim=0,
process_group=shard_config.tensor_parallel_process_group)
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
full_attention_mask,
rotary_pos_emb=rotary_pos_emb,
kv_caches=past_key_values,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
)

hidden_states = gather_forward_split_backward(hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group)

if not return_dict:
return tuple(v for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
] if v is not None)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)

return forward
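
Note: the split/gather pair used above forms a round trip — the forward pass shards the sequence axis across the tensor-parallel group, and the backward pass performs the mirrored collective on the gradient. A toy sketch of that contract (not the PR's test code; assumes a 2-rank group launched with `torchrun --nproc_per_node=2`):

    # Toy sketch of the sequence-parallel round trip.
    import torch
    import torch.distributed as dist
    from colossalai.shardformer.layer._operation import (
        gather_forward_split_backward,
        split_forward_gather_backward,
    )

    def round_trip(pg: dist.ProcessGroup):
        # ChatGLM2 keeps hidden states as [seq_len, batch, hidden],
        # so dim=0 is the sequence axis that gets sharded.
        full = torch.randn(8, 2, 16, requires_grad=True)
        shard = split_forward_gather_backward(full, dim=0, process_group=pg)
        assert shard.shape[0] == full.shape[0] // dist.get_world_size(pg)
        # ... the encoder layers would run on the shard here ...
        restored = gather_forward_split_backward(shard, dim=0, process_group=pg)
        assert restored.shape == full.shape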
18 changes: 12 additions & 6 deletions colossalai/shardformer/policies/bert.py
@@ -155,20 +155,26 @@ def module_policy(self):

# use flash attention
if self.shard_config.enable_flash_attention:
policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_bert_flash_attention_forward(),
})
},
policy=policy,
target_key=BertSelfAttention)

# use jit operator
if self.shard_config.enable_jit_fused:
policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bert_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[BertOutput] = ModulePolicyDescription(method_replacement={
},
policy=policy,
target_key=BertSelfOutput)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bert_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
},
policy=policy,
target_key=BertOutput)

return policy

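Note: the refactor from `policy[key] = ModulePolicyDescription(...)` to `append_or_create_method_replacement` matters when several feature flags target the same module class — a plain assignment would overwrite a description created by an earlier branch (e.g. flash attention clobbering a jit-fused replacement). A paraphrased sketch of the helper's merge semantics, not the library source; it assumes `ModulePolicyDescription` from the base policy module:

    # Paraphrased sketch of the merge behavior (not the library source).
    def append_or_create_method_replacement(description, policy, target_key):
        if target_key in policy:
            # merge into the existing description instead of clobbering it
            if policy[target_key].method_replacement is None:
                policy[target_key].method_replacement = {}
            policy[target_key].method_replacement.update(description)
        else:
            policy[target_key] = ModulePolicyDescription(method_replacement=description)
        return policy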
23 changes: 14 additions & 9 deletions colossalai/shardformer/policies/blip2.py
@@ -285,21 +285,26 @@ def module_policy(self):

# use flash attention
if self.shard_config.enable_flash_attention:
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_blip2_flash_attention_forward(),
})
},
policy=policy,
target_key=Blip2Attention)

# use jit operator
if self.shard_config.enable_jit_fused:
policy[Blip2QFormerSelfOutput] = ModulePolicyDescription(
method_replacement={
'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=Blip2QFormerSelfOutput)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_blip2_QFormer_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
},
policy=policy,
target_key=Blip2QFormerOutput)

return policy

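Note: blip2.py follows the same pattern as bert.py. At shard time, each `method_replacement` entry is bound onto matching module instances; conceptually it looks like the following (a paraphrase — the real logic lives in the sharder, which this diff does not touch):

    # Conceptual sketch: binding a replacement so `self` resolves to the module.
    from types import MethodType

    def apply_method_replacement(module, description):
        for method_name, new_method in description.items():
            setattr(module, method_name, MethodType(new_method, module))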
26 changes: 17 additions & 9 deletions colossalai/shardformer/policies/bloom.py
@@ -125,25 +125,33 @@ def module_policy(self):
target_key=BloomModel)

if self.shard_config.enable_flash_attention:
policy[BloomAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_bloom_flash_attention_forward(),
'dropout_add': get_dropout_add_func()
})
'dropout_add': get_dropout_add_func(),
},
policy=policy,
target_key=BloomAttention)

# enable jit fused operator
if self.shard_config.enable_jit_fused:
policy[BloomAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[BloomMLP] = ModulePolicyDescription(method_replacement={
},
policy=policy,
target_key=BloomAttention)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_mlp_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[BloomGelu] = ModulePolicyDescription(method_replacement={
},
policy=policy,
target_key=BloomMLP)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_gelu_forward(),
'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
})
},
policy=policy,
target_key=BloomGelu)

return policy

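Note: the jit-fused forwards installed above all lean on a scripted dropout-plus-residual helper. A sketch of the shape of that helper (paraphrased from the pattern; the real function is returned by `get_jit_fused_dropout_add_func` and is not shown in this diff):

    # Sketch of a fused dropout+add helper under TorchScript.
    import torch

    @torch.jit.script
    def dropout_add(x: torch.Tensor, residual: torch.Tensor,
                    prob: float, training: bool) -> torch.Tensor:
        # TorchScript lets the dropout and the residual add fuse into one kernel
        out = torch.nn.functional.dropout(x, p=prob, training=training)
        return out + residual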