4 changes: 3 additions & 1 deletion colossalai/nn/layer/base_layer.py
@@ -1,14 +1,16 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+from contextlib import contextmanager
+
 import torch.nn as nn

 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from contextlib import contextmanager


 class ParallelLayer(nn.Module):

+    global_state_dict: bool = True

     def __init__(self):
133 changes: 126 additions & 7 deletions colossalai/shardformer/layer/_operation.py
@@ -54,10 +54,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.parallel_mode = parallel_mode
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce

output = torch.matmul(input_, weight.t())
@@ -74,12 +74,13 @@ def backward(ctx, grad_output):
         grad_input = grad_output.matmul(weight)
         grad_output = grad_output.contiguous()
         # Convert the tensor shapes to 2D for execution compatibility
-        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
-        total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
+        if len(grad_output.shape) > 2:
+            grad_output = grad_output.view(-1, grad_output.shape[-1])
+            total_input = total_input.view(-1, total_input.shape[-1])

         if ctx.async_grad_allreduce:
             # Asynchronous all-reduce
-            handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
+            handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
             # Delay the start of weight gradient computation shortly (3us) to have
             # all-reduce scheduled first and have GPU resources allocated
             _ = torch.empty(1, device=grad_output.device) + 1
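The dummy add plus `async_op=True` in this hunk is the Megatron-style trick for overlapping the input-gradient all-reduce with the weight-gradient matmul. A minimal standalone sketch of the same pattern, not part of this diff and assuming 2D `grad_output`/`total_input` tensors and an already-initialized process group:

```python
import torch
import torch.distributed as dist


def backward_with_overlap(grad_output: torch.Tensor, total_input: torch.Tensor,
                          weight: torch.Tensor, process_group):
    # sketch of the overlap pattern used in LinearWithAsyncCommunication.backward
    grad_input = grad_output.matmul(weight)
    # launch the all-reduce without blocking the host
    handle = dist.all_reduce(grad_input, group=process_group, async_op=True)
    # a tiny op delays the next kernel launch so the all-reduce is scheduled first
    _ = torch.empty(1, device=grad_output.device) + 1
    # this matmul then runs concurrently with the in-flight all-reduce
    grad_weight = grad_output.t().matmul(total_input)
    handle.wait()
    return grad_input, grad_weight
```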
@@ -93,5 +94,123 @@ def backward(ctx, grad_output):
         return grad_input, grad_weight, grad_bias, None, None, None


-def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
-    return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
+class _SplitForwardGatherBackward(torch.autograd.Function):
+    """
+    Split the input and keep only the chunk corresponding to the rank.
+
+    Args:
+        input_ (`torch.Tensor`): input matrix.
+        dim (int): the dimension along which to split and gather.
+        process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication.
+    """
+
+    @staticmethod
+    def forward(ctx, input_, dim, process_group):
+        ctx.process_group = process_group
+        ctx.dim = dim
+        return _split(input_, dim, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather(grad_output, ctx.dim, ctx.process_group), None, None
+
+
+class _ReduceInput(torch.autograd.Function):
+    """
+    All-reduce the input from the model-parallel region.
+
+    Args:
+        input_ (`torch.Tensor`): input matrix.
+        process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication.
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group):
+        return _reduce(input_, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None
+
+
+def _reduce(input_, process_group):
+    # skip if only one rank involved
+    if dist.get_world_size(process_group) == 1:
+        return input_
+    else:
+        dist.all_reduce(input_, group=process_group)
+        return input_
+
+
+def _split(input_, dim=-1, process_group=None):
+    # skip if only one rank involved
+    world_size = dist.get_world_size(process_group)
+    if world_size == 1:
+        return input_
+
+    # split along the given dimension
+    dim_size = input_.size(dim)
+    assert dim_size % world_size == 0, \
+        f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
+        f'cannot split tensor evenly'
+
+    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
+    rank = dist.get_rank(process_group)
+    output = tensor_list[rank].contiguous()
+
+    return output
+
+
+def _gather(input_, dim=-1, process_group=None):
+    # skip if only one rank involved
+    world_size = dist.get_world_size(process_group)
+    if world_size == 1:
+        return input_
+
+    # all-gather
+    rank = dist.get_rank(process_group)
+    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+    tensor_list[rank] = input_
+    dist.all_gather(tensor_list, input_, group=process_group)
+
+    # concat
+    output = torch.cat(tensor_list, dim=dim).contiguous()
+
+    return output
+
+
+class _GatherForwardSplitBackward(torch.autograd.Function):
+    """
+    Gather the input from the model-parallel region and concatenate.
+
+    Args:
+        input_ (`torch.Tensor`): input matrix.
+        dim (int): the dimension along which to gather and split.
+        process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication.
+    """
+
+    @staticmethod
+    def forward(ctx, input_, dim, process_group):
+        ctx.process_group = process_group
+        ctx.dim = dim
+        return _gather(input_, dim, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _split(grad_output, ctx.dim, ctx.process_group), None, None
+
+
+def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
+    return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
+
+
+def gather_forward_split_backward(input_, dim, process_group):
+    return _GatherForwardSplitBackward.apply(input_, dim, process_group)
+
+
+def split_forward_gather_backward(input_, dim, process_group):
+    return _SplitForwardGatherBackward.apply(input_, dim, process_group)
+
+
+def reduce_input(input_, process_group):
+    return _ReduceInput.apply(input_, process_group)
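Note that `gather_forward_split_backward` and `split_forward_gather_backward` are conjugates: each one's backward is the other's forward, so composing them round-trips a tensor and its gradient. A sketch of that property, not part of the diff; it assumes the module path introduced by this PR and runs as a single-rank gloo "cluster", where every collective degenerates to a no-op:

```python
import os

import torch
import torch.distributed as dist

from colossalai.shardformer.layer._operation import (
    gather_forward_split_backward,
    split_forward_gather_backward,
)


def main():
    # single-process stand-in for a real tensor-parallel group
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    group = dist.group.WORLD

    x = torch.randn(2, 4, 8, requires_grad=True)

    # forward: keep this rank's shard of the last dim; backward: all-gather grads
    shard = split_forward_gather_backward(x, dim=-1, process_group=group)
    # forward: all-gather the shards back; backward: re-split the grads
    full = gather_forward_split_backward(shard, dim=-1, process_group=group)

    full.sum().backward()
    # the conjugate pair round-trips both the value and the gradient
    assert torch.allclose(x.grad, torch.ones_like(x))


if __name__ == "__main__":
    main()
```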
54 changes: 8 additions & 46 deletions colossalai/shardformer/layer/dropout.py
@@ -1,58 +1,20 @@
-import os
-from contextlib import contextmanager
-
 import torch
-import torch.distributed as dist
 import torch.nn as nn

-
-class SeedManager:
-    """
-    This class is a random state manager to change random state for different random seed.
-
-    """
-
-    def __init__(self):
-        original_state = torch.cuda.get_rng_state()
-        # TODO: unify this seed manager with the colossalai.context.random
-        seed = os.getpid()
-        torch.cuda.manual_seed(int(seed))
-        self.dropout_state = torch.cuda.get_rng_state()
-        torch.cuda.set_rng_state(original_state)
-
-    def set_mode(self, rng_state):
-        torch.cuda.set_rng_state(rng_state)
-
-    def get_current_mode(self):
-        current_state = torch.cuda.get_rng_state()
-        return current_state
-
-    @contextmanager
-    def dropout_mode(self):
-        """
-        This is a context manager to change the dropout state and recover the original state.
-
-        Usage:
-        ::
-            >>> with _seed_manager.dropout_mode():
-            >>>     input = super().forward(input)
-        """
-        try:
-            current_mode = self.get_current_mode()
-            yield self.set_mode(self.dropout_state)
-        finally:
-            self.dropout_state = self.get_current_mode()
-            self.set_mode(current_mode)
-
-
-_seed_manager = SeedManager()
+from .utils import create_randomizer_with_offset


 class Dropout1D(nn.Dropout):

-    def __init__(self, p=0.5, inplace=False):
+    def __init__(self, p=0.5, inplace=False, process_group=None):
         super().__init__(p, inplace)

+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=process_group)
+
     def forward(self, input):
-        with _seed_manager.dropout_mode():
+        with self.randomizer.fork_rng():
             input = super().forward(input)
         return input
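The replacement swaps the process-wide `SeedManager` for a per-layer randomizer whose seed is offset by rank, so each tensor-parallel rank draws an independent dropout mask while leaving the ambient RNG stream alone (the same intent the removed `dropout_mode` context manager had). A usage sketch, not part of the diff; it assumes `torch.distributed` is initialized, CUDA is available, and the import path introduced by this PR:

```python
import torch
import torch.distributed as dist

from colossalai.shardformer.layer.dropout import Dropout1D

# in real code this would be the tensor-parallel group
group = dist.group.WORLD

# each rank's randomizer is seeded with initial_seed() plus a rank offset,
# so masks differ across ranks even though the user-visible seed is shared
drop = Dropout1D(p=0.1, process_group=group).cuda()

x = torch.randn(4, 8, device="cuda")
y = drop(x)  # the mask is drawn inside randomizer.fork_rng()
```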