10 changes: 8 additions & 2 deletions colossalai/shardformer/layer/embedding.py
@@ -68,6 +68,7 @@ def __init__(
gather_output: bool = True,
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
fp8_communication: bool = False,
*args,
**kwargs,
):
@@ -81,6 +82,7 @@ def __init__(
self.embed_args = args
self.embed_kwargs = kwargs
self.gather_output = gather_output
self.fp8_communication = fp8_communication

# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
@@ -155,7 +157,9 @@ def _fill_padding_idx_with_zero(self) -> None:
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
if self.gather_output:
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
return output
else:
return output_parallel
@@ -274,6 +278,7 @@ def __init__(
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
make_vocab_size_divisible_by: int = 64,
fp8_communication: bool = False,
*args,
**kwargs,
):
@@ -282,6 +287,7 @@ def __init__(
self.embed_args = args
self.embed_kwargs = kwargs
self.process_group = process_group
self.fp8_communication = fp8_communication

tensor_parallel_size = dist.get_world_size(group=process_group)
tensor_parallel_rank = dist.get_rank(group=process_group)
@@ -390,5 +396,5 @@ def forward(self, input_: Tensor) -> Tensor:
embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_forward(embedding_output, self.process_group)
output = reduce_forward(embedding_output, self.process_group, fp8_communication=self.fp8_communication)
return output
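The new `fp8_communication` flag lets these embedding layers route their collectives (`gather_forward_split_backward`, `reduce_forward`) through an FP8-compressed path. Below is a minimal sketch of what such a compressed all-gather can look like, assuming a PyTorch build with `torch.float8_e4m3fn`; the helper name and the per-tensor scaling scheme are illustrative assumptions, not ColossalAI's actual implementation.

```python
# Illustrative sketch only: an all-gather that ships FP8 payloads plus a per-tensor scale.
# Not ColossalAI's implementation; names and the scaling scheme are assumptions.
import torch
import torch.distributed as dist


def all_gather_fp8(tensor: torch.Tensor, group=None, dim: int = -1) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    orig_dtype = tensor.dtype

    # Per-tensor scale so values fit float8_e4m3fn's range (|x| up to ~448).
    scale = (tensor.abs().amax().clamp(min=1e-12) / 448.0).reshape(1)
    payload = (tensor / scale).to(torch.float8_e4m3fn).contiguous()

    # NCCL has no FP8 dtype, so the payload travels as raw bytes.
    byte_view = payload.view(torch.uint8)
    byte_chunks = [torch.empty_like(byte_view) for _ in range(world_size)]
    scale_chunks = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(byte_chunks, byte_view, group=group)
    dist.all_gather(scale_chunks, scale, group=group)

    # Dequantize every shard back to the original dtype and stitch along `dim`.
    shards = [
        chunk.view(torch.float8_e4m3fn).to(orig_dtype) * s
        for chunk, s in zip(byte_chunks, scale_chunks)
    ]
    return torch.cat(shards, dim=dim)
```

Shipping one byte per element instead of two (bf16/fp16) or four (fp32) roughly halves or quarters the gather traffic, at the cost of the extra casts and a small scale exchange.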
2 changes: 2 additions & 0 deletions colossalai/shardformer/layer/linear.py
@@ -572,6 +572,7 @@ def __init__(
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
make_vocab_size_divisible_by: int = 64,
fp8_communication: bool = False,
**kwargs,
):
# create weight and bias
@@ -602,6 +603,7 @@ def __init__(
**kwargs,
new_num_embeddings=new_out_features,
old_num_embeddings=out_features,
fp8_communication=fp8_communication,
)
# get the length of valid embeddings
tp_rank = dist.get_rank(process_group)
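The linear.py change only threads the flag through the padded LM head's call to its parent constructor. For reference, the `make_vocab_size_divisible_by` argument seen alongside it controls rounding the output/vocab size up before the column-parallel weight is built; the helper below is a plain round-up illustration (the library's exact rule may also fold in the tensor-parallel size).

```python
def pad_vocab_size(out_features: int, make_vocab_size_divisible_by: int = 64) -> int:
    """Round the vocab/output dimension up to the next multiple before sharding it."""
    multiple = make_vocab_size_divisible_by
    return ((out_features + multiple - 1) // multiple) * multiple


# e.g. GPT-2's 50257-token vocab padded to a multiple of 64:
assert pad_vocab_size(50257) == 50304
```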
6 changes: 5 additions & 1 deletion colossalai/shardformer/layer/qkv_fused_linear.py
@@ -627,6 +627,7 @@ def __init__(
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
):
super().__init__()
# Keep input parameters
@@ -638,6 +639,7 @@ def __init__(
self.n_fused = n_fused
self.process_group = process_group
self.async_communication = async_communication
self.fp8_communication = fp8_communication

if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@@ -767,7 +769,9 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:

if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
else:
output = output_parallel

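As in embedding.py, the fused QKV layer's `gather_output` branch now forwards the flag to `gather_forward_split_backward`. Conceptually that helper all-gathers the column-sharded activations in the forward pass and hands each rank back only its own slice of the gradient in the backward pass. A rough autograd sketch of those semantics follows; it is illustrative only, and when `fp8_communication` is set the real implementation would swap the all-gather for an FP8-compressed variant like the one sketched above.

```python
import torch
import torch.distributed as dist


class _GatherForwardSplitBackward(torch.autograd.Function):
    """All-gather along `dim` in forward; keep only the local gradient slice in backward."""

    @staticmethod
    def forward(ctx, input_, dim, process_group, fp8_communication=False):
        ctx.dim = dim
        ctx.process_group = process_group
        world_size = dist.get_world_size(process_group)
        chunks = [torch.empty_like(input_) for _ in range(world_size)]
        # With fp8_communication=True, this all_gather would be replaced by the
        # FP8-compressed variant; the autograd semantics stay identical.
        dist.all_gather(chunks, input_.contiguous(), group=process_group)
        return torch.cat(chunks, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient of an all-gather: each rank keeps the slice it contributed.
        world_size = dist.get_world_size(ctx.process_group)
        rank = dist.get_rank(ctx.process_group)
        grad = grad_output.chunk(world_size, dim=ctx.dim)[rank].contiguous()
        return grad, None, None, None
```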
30 changes: 24 additions & 6 deletions colossalai/shardformer/modeling/bert.py
@@ -187,11 +187,17 @@ def bert_model_forward(
if shard_config is not None and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if encoder_hidden_states is not None:
encoder_hidden_states = split_forward_gather_backward(
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
encoder_hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)

for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
@@ -242,7 +248,10 @@ def custom_forward(*inputs):
if shard_config is not None and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)

if output_hidden_states:
@@ -1135,11 +1144,17 @@ def forward(
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
embedding_output = split_forward_gather_backward(
embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group
embedding_output,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if encoder_hidden_states is not None:
encoder_hidden_states = split_forward_gather_backward(
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
encoder_hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)

encoder_outputs = self.encoder(
@@ -1159,7 +1174,10 @@ def forward(

# When sequence parallelism done, gather the output tensor in forward and split it in backward
sequence_output = gather_forward_split_backward(
sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group
sequence_output,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)

pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
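All of the BERT hunks follow the same "split_gather" sequence-parallel sandwich: split activations along the sequence dimension (dim=1 for the [batch, seq_len, hidden] layout) right after the embedding, run the encoder on the local shard, and gather the full sequence back before the pooler or task head, with the FP8 option now riding on each collective. Condensed, the added control flow looks roughly like the sketch below; it is a paraphrase of the diff rather than a drop-in function, and the import path is assumed to match what the modeling files use.

```python
# Paraphrase of the sequence-parallel flow the BERT hunks wire up (not a drop-in function).
from colossalai.shardformer.layer._operation import (  # assumed import path
    gather_forward_split_backward,
    split_forward_gather_backward,
)


def encode_with_split_gather(self, embedding_output, shard_config, **encoder_kwargs):
    sp_on = shard_config is not None and shard_config.enable_sequence_parallelism
    split_gather = sp_on and shard_config.sequence_parallelism_mode == "split_gather"

    if split_gather:
        # [batch, seq_len, hidden] -> [batch, seq_len / tp_size, hidden]
        embedding_output = split_forward_gather_backward(
            embedding_output,
            dim=1,
            process_group=shard_config.tensor_parallel_process_group,
            fp8_communication=shard_config.fp8_communication,
        )

    sequence_output = self.encoder(embedding_output, **encoder_kwargs)[0]

    if split_gather:
        # Gather the full sequence back before the pooler / task head.
        sequence_output = gather_forward_split_backward(
            sequence_output,
            dim=1,
            process_group=shard_config.tensor_parallel_process_group,
            fp8_communication=shard_config.fp8_communication,
        )
    return sequence_output
```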
20 changes: 16 additions & 4 deletions colossalai/shardformer/modeling/bloom.py
@@ -221,7 +221,10 @@ def bloom_model_forward(
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)

start_idx, end_idx = stage_index[0], stage_index[1]
@@ -264,7 +267,10 @@ def bloom_model_forward(
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)

if stage_manager.is_last_stage():
@@ -922,7 +928,10 @@ def forward(
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)

for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -960,7 +969,10 @@ def forward(

# When sequence parallelism done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
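Bloom gets the same split_gather sandwich as BERT above: split the hidden states on dim=1 after the word embeddings and gather them back before `ln_f`, with `fp8_communication` forwarded on both collectives. The one layout difference worth noting against the ChatGLM2 changes below is the split dimension: Bloom, BERT, and Command keep activations as [batch, seq_len, hidden] (so the sequence dim is 1), while ChatGLM2 uses [seq_len, batch, hidden] (so it splits and gathers dim 0).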
37 changes: 33 additions & 4 deletions colossalai/shardformer/modeling/chatglm2.py
@@ -206,13 +206,15 @@ def chatglm_model_forward(
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
@@ -245,13 +247,15 @@ def chatglm_model_forward(
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -414,13 +418,15 @@ def forward(
inputs_embeds,
dim=0,
process_group=sp_group,
fp8_communication=shard_config.fp8_communication,
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(
inputs_embeds,
dim=0,
process_group=sp_group,
grad_scale=1 / sp_size,
fp8_communication=shard_config.fp8_communication,
)
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
@@ -436,13 +442,15 @@ def forward(
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=sp_group,
grad_scale=sp_size,
fp8_communication=shard_config.fp8_communication,
)

if not return_dict:
@@ -532,9 +540,24 @@ def forward(
key_layer = key_layer.reshape(sq, bs, -1)
value_layer = value_layer.reshape(sq, bs, -1)

query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)
query_layer = all_to_all_comm(
query_layer,
sp_group,
gather_dim=0,
fp8_communication=shard_config.fp8_communication,
)
key_layer = all_to_all_comm(
key_layer,
sp_group,
gather_dim=0,
fp8_communication=shard_config.fp8_communication,
)
value_layer = all_to_all_comm(
value_layer,
sp_group,
gather_dim=0,
fp8_communication=shard_config.fp8_communication,
)

query_layer = query_layer.view(
sq * sp_size,
@@ -610,7 +633,13 @@ def forward(

context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
if sp_mode == "all_to_all":
context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)
context_layer = all_to_all_comm(
context_layer,
sp_group,
gather_dim=2,
scatter_dim=0,
fp8_communication=shard_config.fp8_communication,
)

# =================
# Output. [sq, b, h]
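Besides the split/gather pairs, the ChatGLM2 attention path now forwards the flag to `all_to_all_comm`, which reshards q/k/v from sequence-sharded to hidden-sharded before attention (gather_dim=0) and reverses the layout on the context afterwards (gather_dim=2, scatter_dim=0). Below is a simplified stand-in showing what that resharding does, plus the shape bookkeeping for ChatGLM's [seq, batch, hidden] layout; it is illustrative only, and the real helper adds the FP8 payload handling when the flag is set.

```python
import torch
import torch.distributed as dist


def all_to_all_reshard(x: torch.Tensor, group, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    """Simplified stand-in for all_to_all_comm: split x along scatter_dim, exchange one
    chunk with every rank, and concatenate the received chunks along gather_dim."""
    world_size = dist.get_world_size(group)
    send_chunks = [c.contiguous() for c in x.chunk(world_size, dim=scatter_dim)]
    recv_chunks = [torch.empty_like(c) for c in send_chunks]
    dist.all_to_all(recv_chunks, send_chunks, group=group)
    return torch.cat(recv_chunks, dim=gather_dim)


# Shape bookkeeping for sp_size-way sequence parallelism in the [seq, batch, hidden] layout:
#   q/k/v before attention:   [sq_local, bs, h]
#     reshard scattering hidden, gathering sequence -> [sq_local * sp_size, bs, h / sp_size]
#     (hence the view to `sq * sp_size` right after the q/k/v all-to-alls in the diff)
#   context after attention:  [sq_local * sp_size, bs, h / sp_size]
#     reshard with scatter_dim=0, gather_dim=2     -> [sq_local, bs, h]
```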
30 changes: 22 additions & 8 deletions colossalai/shardformer/modeling/command.py
@@ -140,13 +140,15 @@ def command_model_forward(
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)

# decoder layers
@@ -211,13 +213,15 @@ def command_model_forward(
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)

# add hidden states from the last decoder layer
@@ -382,9 +386,9 @@ def forward(

# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -446,7 +450,9 @@ def forward(
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -526,9 +532,13 @@ def forward(
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
hidden_states = inputs_embeds

# decoder layers
@@ -573,9 +583,13 @@ def forward(
hidden_states = self.norm(hidden_states)

if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)

# add hidden states from the last decoder layer
if output_hidden_states:
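The Command (Cohere) model adds the same flag to both the split_gather and all_to_all paths. In the all_to_all branch the collectives also carry `grad_scale`: 1/sp_size when splitting the embeddings at the model input and sp_size when gathering the final hidden states, so the two factors cancel over the whole backward pass and gradient magnitudes match the non-sequence-parallel run. Below is a sketch of the split-side counterpart to the gather op above, with that grad_scale bookkeeping; again illustrative, not the library's code, and an FP8 variant would compress the backward all-gather the same way.

```python
import torch
import torch.distributed as dist


class _SplitForwardGatherBackward(torch.autograd.Function):
    """Keep the local shard along `dim` in forward; reassemble (and optionally rescale)
    the full gradient in backward."""

    @staticmethod
    def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
        ctx.dim = dim
        ctx.process_group = process_group
        ctx.grad_scale = grad_scale
        world_size = dist.get_world_size(process_group)
        rank = dist.get_rank(process_group)
        # Each rank keeps only its slice of the sequence dimension.
        return input_.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
        # Reassemble the full-sequence gradient; with fp8_communication this all_gather
        # would run on compressed payloads.
        dist.all_gather(chunks, grad_output.contiguous(), group=ctx.process_group)
        grad = torch.cat(chunks, dim=ctx.dim)
        if ctx.grad_scale is not None:
            # 1/sp_size here and sp_size on the matching gather cancel end to end,
            # keeping gradient magnitudes identical to the non-sequence-parallel run.
            grad = grad * ctx.grad_scale
        return grad, None, None, None, None
```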