1 change: 1 addition & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1205,6 +1205,7 @@ def configure(
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
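Context for this change: under all_to_all sequence parallelism every SP rank trains on a different shard of the sequence, so its gradients differ across SP ranks and must be averaged over the combined DP × SP ranks rather than DP alone, which is why the plugin now builds one group spanning both axes. Below is a rough, hedged sketch of the same idea in plain torch.distributed; the row-major (dp, sp, tp) rank layout and the helper name are assumptions for illustration, not ColossalAI's ProcessGroupMesh API.

```python
import torch.distributed as dist

def build_dp_sp_group(dp_size: int, sp_size: int, tp_size: int):
    """Hypothetical helper: one gradient-sync group per TP coordinate,
    spanning all (dp, sp) pairs, assuming a row-major (dp, sp, tp) layout."""
    rank, world = dist.get_rank(), dist.get_world_size()
    assert world == dp_size * sp_size * tp_size
    my_group = None
    for tp in range(tp_size):
        # every rank must enter new_group() for every group being created
        ranks = [
            (dp * sp_size + sp) * tp_size + tp
            for dp in range(dp_size)
            for sp in range(sp_size)
        ]
        group = dist.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group
```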
3 changes: 2 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -3,7 +3,7 @@
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
@@ -18,6 +18,7 @@
"DropoutForParallelInput",
"DropoutForReplicatedInput",
"cross_entropy_1d",
"dist_cross_entropy",
"BaseLayerNorm",
"LayerNorm",
"RMSNorm",
45 changes: 44 additions & 1 deletion colossalai/shardformer/layer/loss.py
@@ -2,8 +2,11 @@
import torch.distributed as dist
from torch.autograd import Function
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss

__all__ = ["DistCrossEntropy", "cross_entropy_1d"]
from colossalai.shardformer.shard import ShardConfig

__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]


class DistCrossEntropy(Function):
@@ -132,3 +135,43 @@ def cross_entropy_1d(
dtype: torch.dtype = None,
) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)


def dist_cross_entropy(
labels: torch.Tensor,
logits: torch.Tensor,
shard_config: ShardConfig,
out_features: int,
vocab_size: int,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Helper to compute cross entropy loss for most shardformer models,
compatible with PP, TP and SP.
"""
    loss = None
    if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
# Cross entropy with all-reduce for TP
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=out_features,
dtype=dtype,
)
else:
            # NOTE: if TP is used without parallel_output, the logits are already gathered;
            # see VocabParallelLMHead1D
shift_logits = shift_logits.view(-1, vocab_size)
loss = loss_fct(shift_logits, shift_labels)

return loss
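To make the helper's behaviour concrete: outside tensor parallelism it should reduce to the usual shifted next-token cross entropy. The snippet below is a hedged, single-process sketch rather than part of this diff; it assumes ShardConfig can be constructed with enable_tensor_parallelism=False and no process groups.

```python
import torch
from colossalai.shardformer.layer import dist_cross_entropy
from colossalai.shardformer.shard import ShardConfig

logits = torch.randn(2, 8, 32000)         # [batch, seq_len, vocab]
labels = torch.randint(0, 32000, (2, 8))  # [batch, seq_len]
# assumption: a ShardConfig built this way takes the non-TP branch (plain CrossEntropyLoss)
cfg = ShardConfig(enable_tensor_parallelism=False)
loss = dist_cross_entropy(
    labels, logits, cfg,
    out_features=32000,  # full vocab size of the LM head
    vocab_size=32000,    # vocab size used when logits are not vocab-sharded
    dtype=logits.dtype,
)
print(loss)  # scalar cross entropy over the shifted tokens
```

Under TP with parallel_output enabled, the same call instead dispatches to cross_entropy_1d, which all-reduces across the tensor-parallel process group.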
56 changes: 13 additions & 43 deletions colossalai/shardformer/modeling/bloom.py
@@ -28,7 +28,7 @@
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig

from ..layer import cross_entropy_1d
from ..layer import dist_cross_entropy

logger = logging.get_logger(__name__)

@@ -359,30 +359,14 @@ def bloom_for_causal_lm_forward(
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states).contiguous()

loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = lm_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels.view(-1))
loss = dist_cross_entropy(
labels,
lm_logits,
shard_config,
self.lm_head.out_features,
self.config.vocab_size,
self.transformer.dtype,
)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@@ -1040,24 +1024,10 @@ def forward(
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)

loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
new_vocab_size = lm_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
55 changes: 12 additions & 43 deletions colossalai/shardformer/modeling/command.py
@@ -5,7 +5,6 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
@@ -25,7 +24,7 @@
)
from colossalai.shardformer.shard import ShardConfig

from ..layer import ColoAttention, cross_entropy_1d
from ..layer import ColoAttention, dist_cross_entropy


class CommandPipelineForwards:
@@ -300,29 +299,9 @@ def command_for_causal_lm_forward(
logits = self.lm_head(hidden_states)
logits = logits * self.logit_scale
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)

if not return_dict:
output = (logits,) + outputs[1:]
@@ -658,24 +637,14 @@ def forward(
logits = self.lm_head(hidden_states)
logits = logits * self.logit_scale
logits = logits.float()

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
loss = dist_cross_entropy(
labels,
logits,
shard_config,
self.lm_head.out_features,
self.config.vocab_size,
self.model.dtype,
)

if not return_dict:
output = (logits,) + outputs[1:]
47 changes: 7 additions & 40 deletions colossalai/shardformer/modeling/gpt2.py
@@ -25,7 +25,7 @@
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig

from ..layer import cross_entropy_1d
from ..layer import dist_cross_entropy

logger = logging.get_logger(__name__)

@@ -372,27 +372,9 @@ def gpt2_lmhead_model_forward(

hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss = loss_fct(shift_logits, shift_labels)
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)

if not return_dict:
output = (lm_logits,) + outputs[1:]
@@ -1282,24 +1264,9 @@ def forward(
hidden_states = transformer_outputs[0]

lm_logits = self.lm_head(hidden_states)

loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]