20 changes: 12 additions & 8 deletions colossalai/shardformer/layer/dist_crossentropy.py
@@ -3,6 +3,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.distributed import ProcessGroup


class DistCrossEntropy(Function):
@@ -14,7 +15,7 @@ class DistCrossEntropy(Function):
"""

@staticmethod
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int):
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup):
r"""
Calculate the cross entropy loss before gather; the original loss function is:
loss = -log(exp(x[class]) / sum(exp(x[i])))
@@ -34,15 +35,15 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index:
"""
# get the max
logits_max = torch.max(vocab_logits, dim=-1)[0]
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX)
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)

# subtract the max to avoid the sum of exp becoming too large and the log turning into nan
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)

# mask the target in the local device
partition_vocab_size = vocab_logits.size()[-1]
rank = dist.get_rank()
world_size = dist.get_world_size()
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
global_vocab_size = partition_vocab_size * world_size

# [down, up) => false, other device and -100 => true
@@ -67,11 +68,11 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index:
pred_logits[mask] = 0.0

# all-reduce to get all x(i, y)
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM)
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

# calculate the loss
# loss = log(sum(exp(x[i]))) - x[class]
@@ -101,5 +102,8 @@ def backward(ctx, grad_output):
return grad_logits, None, None


def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index)
def cross_entropy_1d(vocab_logits: torch.Tensor,
labels: torch.Tensor,
ignore_index: int = -100,
process_group: ProcessGroup = None) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
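
The new `cross_entropy_1d` entry point threads an explicit `process_group` through every collective in the loss (the logit max, the target-logit gather, and the exp-sum are all all-reduced over that group), so the loss can run inside a tensor-parallel subgroup instead of the default world group. A minimal usage sketch follows; the group layout, shapes, device, and backend are assumptions for illustration, not part of this PR:

```python
# Sketch only: assumes torch.distributed is already initialized (e.g. via
# colossalai.launch) with an NCCL backend, and that the module path matches
# the file shown in the diff above.
import torch
import torch.distributed as dist

from colossalai.shardformer.layer.dist_crossentropy import cross_entropy_1d

tp_group = dist.new_group(ranks=[0, 1])  # hypothetical 2-way tensor parallelism

# each rank holds only its vocabulary shard of the logits ...
sharded_logits = torch.randn(2, 128, 16000, device="cuda", requires_grad=True)
# ... while labels index the *global* vocabulary (2 * 16000 = 32000 here)
labels = torch.randint(0, 32000, (2, 128), device="cuda")

loss = cross_entropy_1d(sharded_logits, labels, ignore_index=-100, process_group=tp_group)
```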
32 changes: 28 additions & 4 deletions colossalai/shardformer/layer/dropout.py
@@ -1,19 +1,43 @@
from typing import List, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

from .layers import ParallelModule
from .utils import create_randomizer_with_offset


class Dropout1D(nn.Dropout):
class Dropout1D(ParallelModule, nn.Dropout):
"""
The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with
randomness on different ranks of the given process group. This can avoid the same dropout mask is generated
and applied on the same position of different ranks, leading to poor convergence performance.

Args:
p (float): probability of an element to be zeroed. Defaults to 0.5.
inplace (bool): If set to True, will do this operation in-place. Defaults to False.
process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None.
"""

def __init__(self, p=0.5, inplace=False, process_group=None):
super().__init__(p, inplace)
def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None):
# init with nn.Dropout
super(nn.Dropout, self).__init__(p=p, inplace=inplace)

# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=process_group)

@staticmethod
def from_native_module(module: nn.Dropout,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Dropout1D":
"""
Create a Dropout1D layer from a native dropout layer.
"""
p = module.p
inplace = module.inplace
return Dropout1D(p=p, inplace=inplace, process_group=process_group)

def forward(self, input):
with self.randomizer.fork_rng():
input = super().forward(input)
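
Dropout1D now registers as a `ParallelModule` and seeds its randomizer with a rank-dependent offset inside the given process group, so ranks that process the same positions draw different dropout masks. The new `from_native_module` hook is what a sharding policy would call to swap in the parallel-aware layer. A small conversion sketch, with the group setup and shapes assumed for illustration:

```python
# Sketch only: assumes torch.distributed is already initialized and that the
# module path matches the file shown in the diff above.
import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.shardformer.layer.dropout import Dropout1D

tp_group = dist.new_group(ranks=[0, 1])  # hypothetical tensor-parallel subgroup

native = nn.Dropout(p=0.1)
parallel_dropout = Dropout1D.from_native_module(native, process_group=tp_group)

x = torch.randn(4, 256)
out = parallel_dropout(x)  # forward forks the RNG with a per-rank offset,
                           # so the zeroed positions differ across ranks in tp_group
```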