8 changes: 5 additions & 3 deletions colossalai/shardformer/README.md
@@ -83,8 +83,10 @@ We will follow this roadmap to develop Shardformer:
- [x] BERT
- [x] T5
- [x] LlaMa
- [ ] GPT2
- [ ] BLOOM
- [x] GPT2
- [x] OPT
- [x] BLOOM
- [ ] GLM
- [ ] RoBERTa
- [ ] ALBERT
- [ ] ERNIE
@@ -96,7 +98,7 @@ We will follow this roadmap to develop Shardformer:
- [ ] SwinTransformer
- [ ] SwinTransformer V2
- [ ] Audio
- [ ] To be added
- [ ] Whisper
- [ ] Multi-modal
- [ ] To be added

9 changes: 5 additions & 4 deletions colossalai/shardformer/layer/__init__.py
@@ -1,11 +1,12 @@
from .dropout import Dropout1D
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .layernorm import FusedLayerNorm
from .linear import Linear1D_Col, Linear1D_Row
from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
from .loss import cross_entropy_1d
from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

__all__ = [
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
"Dropout1D", "cross_entropy_1d", 'FusedLayerNorm'
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
'FusedLayerNorm'
]
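
For readers tracking the rename, downstream imports after this diff would look like the sketch below — a reading aid mirroring the new `__all__`, not an excerpt from the PR:

```python
# Sketch of the renamed public API, assuming the package layout shown in
# __init__.py above; the names mirror the new __all__ exactly.
from colossalai.shardformer.layer import (
    DropoutForParallelInput,      # rank-distinct dropout masks (was Dropout1D)
    DropoutForReplicatedInput,    # rank-identical dropout masks (new)
    GPT2FusedLinearConv1D_Col,    # was LinearConv1D_Col
    GPT2FusedLinearConv1D_Row,    # was LinearConv1D_Row
)
```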
45 changes: 41 additions & 4 deletions colossalai/shardformer/layer/dropout.py
@@ -7,10 +7,10 @@
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset

__all__ = ['Dropout1D']
__all__ = ['DropoutForParallelInput', 'DropoutForReplicatedInput']


class Dropout1D(ParallelModule, nn.Dropout):
class DropoutForParallelInput(ParallelModule, nn.Dropout):
"""
The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with
randomness on different ranks of the given process group. This can avoid the same dropout mask is generated
@@ -32,13 +32,50 @@ def __init__(self, p: float = 0.5, inplace: bool = False, process_group: Process

@staticmethod
def from_native_module(module: nn.Dropout,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Dropout1D":
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForParallelInput":
"""
Create a DropoutForParallelInput layer from a native dropout layer.
"""
p = module.p
inplace = module.inplace
return DropoutForParallelInput(p=p, inplace=inplace, process_group=process_group)

def forward(self, input):
with self.randomizer.fork_rng():
input = super().forward(input)
return input


class DropoutForReplicatedInput(ParallelModule, nn.Dropout):
"""
The Dropout layer applies a dropout mask to the input tensor. Unlike DropoutForParallelInput, the
dropout mask is generated identically on all ranks of the given process group (the seed is not offset
by rank), so a replicated input stays consistent across ranks after dropout.

Args:
p (float): probability of an element to be zeroed. Defaults to 0.5.
inplace (bool): If set to True, will do this operation in-place. Defaults to False.
process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None.
"""

def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None):
# init with nn.Dropout
super(nn.Dropout, self).__init__(p=p, inplace=inplace)

# offset the seed with randomizer index only
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=process_group, offset_by_rank=False)

@staticmethod
def from_native_module(
module: nn.Dropout,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForReplicatedInput":
"""
Create a DropoutForReplicatedInput layer from a native dropout layer.
"""
p = module.p
inplace = module.inplace
return Dropout1D(p=p, inplace=inplace, process_group=process_group)
return DropoutForReplicatedInput(p=p, inplace=inplace, process_group=process_group)

def forward(self, input):
with self.randomizer.fork_rng():
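
The only difference between the two classes is whether the randomizer offsets the seed by rank. A minimal single-process sketch of that seeding scheme, with `make_generator` as a hypothetical stand-in for `create_randomizer_with_offset`:

```python
# Single-process sketch of the seeding scheme behind the two dropout variants.
# make_generator is a hypothetical stand-in for create_randomizer_with_offset.
import torch

def make_generator(base_seed: int, rank: int, offset_by_rank: bool) -> torch.Generator:
    # DropoutForParallelInput offsets the seed by rank -> distinct masks per rank.
    # DropoutForReplicatedInput shares one seed -> identical masks on every rank.
    seed = base_seed + rank if offset_by_rank else base_seed
    return torch.Generator().manual_seed(seed)

probs = torch.full((8,), 0.5)
base_seed = torch.random.initial_seed()

# Parallel input: masks on "rank 0" and "rank 1" differ (with high probability).
m0 = torch.bernoulli(probs, generator=make_generator(base_seed, 0, True))
m1 = torch.bernoulli(probs, generator=make_generator(base_seed, 1, True))

# Replicated input: masks are identical, keeping replicas in sync.
r0 = torch.bernoulli(probs, generator=make_generator(base_seed, 0, False))
r1 = torch.bernoulli(probs, generator=make_generator(base_seed, 1, False))
assert torch.equal(r0, r1)
```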
8 changes: 5 additions & 3 deletions colossalai/shardformer/layer/linear.py
@@ -277,6 +277,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
def chunk_weight(self):
self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)

@torch.no_grad()
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
@@ -289,9 +290,10 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None:
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)

origin_device = self.bias.device
self.bias = self.bias.cuda()
dist.broadcast(self.bias, src=src_rank, group=self.process_group)
self.bias = self.bias.to(origin_device)
bias = self.bias.cuda()
dist.broadcast(bias, src=src_rank, group=self.process_group)
bias = bias.to(origin_device)
self.bias.copy_(bias)

def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
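
Two fixes land in this hunk: `reset_parameters` now runs under `@torch.no_grad()` (an in-place `copy_` into a leaf Parameter is otherwise rejected by autograd), and the broadcast writes into a temporary instead of rebinding `self.bias`. A minimal sketch of why rebinding fails, using stock PyTorch rather than code from this PR:

```python
# Why the hunk broadcasts into a temporary and copies back: nn.Module refuses
# to rebind a registered Parameter to a plain tensor (e.g. the result of
# .cuda()), and copy_ into a leaf Parameter must run without grad tracking.
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
try:
    layer.bias = layer.bias.detach().clone()   # plain Tensor, not a Parameter
except TypeError as err:
    print(err)   # cannot assign 'torch.FloatTensor' as parameter 'bias' ...

bias = layer.bias.detach().clone()   # stand-in for the .cuda() copy in the diff
bias += 1.0                          # stand-in for dist.broadcast(bias, ...)
with torch.no_grad():
    layer.bias.copy_(bias)           # in-place update; bias stays a Parameter
```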
colossalai/shardformer/layer/linear_conv.py renamed to colossalai/shardformer/layer/qkv_fused_linear.py
@@ -31,12 +31,25 @@
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset

__all__ = ['LinearConv1D_Col', 'LinearConv1D_Row']
__all__ = ['GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row']

# ====================================
# For GPT Only
# ====================================

def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):

def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor,
n_fused: int,
process_group: ProcessGroup,
is_transposed: bool = False):
"""
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2]; this function splits it into [Q1, K1, V1] and [Q2, K2, V2].

Args:
qkv (torch.Tensor): The fused qkv tensor.
n_fused (int): The number of items fused together, defaults to 3 (query, key and value).
process_group (ProcessGroup): The process group for distributed communication.
is_transposed (bool): Generally the tensor has shape (out_features, in_features). Set this to True if it has shape (in_features, out_features) instead.
"""
# get the number of slices for the fused qkv
rank = dist.get_rank(group=process_group)
@@ -48,21 +61,37 @@ def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup
# [Q, K, V]
# to
# [Q1, Q2, K1, K2, V1, V2]
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
if is_transposed:
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
else:
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0)

# rearrange the slice into the final order
# from
# [Q1, Q2, K1, K2, V1, V2]
# to
# [Q1, K1, V1], [Q2, K2, V2]
weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]]
weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1)

if is_transposed:
weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1)
else:
weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=0)
return weight_of_current_rank
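
Concretely, for world_size = 2 and n_fused = 3 the split reduces to a chunk-and-stride selection, as in this toy sketch (plain ints stand in for the process group):

```python
# Toy sketch of the GPT-2-style fused-QKV split, with rank/world_size as
# plain ints instead of a process group; columns stand for Q1 Q2 K1 K2 V1 V2.
import torch

def split_gpt2_style(qkv: torch.Tensor, rank: int, world_size: int, n_fused: int = 3):
    chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
    # Take every world_size-th chunk starting at `rank`:
    # rank 0 keeps (Q1, K1, V1), rank 1 keeps (Q2, K2, V2).
    return torch.cat(chunks[rank::world_size], dim=-1)

qkv = torch.arange(6.).reshape(1, 6)                # [[0, 1, 2, 3, 4, 5]]
print(split_gpt2_style(qkv, rank=0, world_size=2))  # [[0., 2., 4.]] -> Q1 K1 V1
print(split_gpt2_style(qkv, rank=1, world_size=2))  # [[1., 3., 5.]] -> Q2 K2 V2
```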


def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
def gather_fused_qkv_in_gpt2_style(qkv: torch.Tensor,
n_fused: int,
process_group: ProcessGroup,
is_transposed: bool = False):
"""
The split qkv tensors look like [Q1, K1, V1] and [Q2, K2, V2]; this function gathers them into [Q1, Q2, K1, K2, V1, V2].

Args:
qkv (torch.Tensor): The fused qkv tensor.
n_fused (int): The number of items fused together, defaults to 3 (query, key and value).
process_group (ProcessGroup): The process group for distributed communication.
is_transposed (bool): Generally the tensor has shape (out_features, in_features). Set this to True if it has shape (in_features, out_features) instead.
"""
world_size = dist.get_world_size(group=process_group)

@@ -75,7 +104,11 @@ def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGrou
qkv = qkv.cuda()
gather_list = [torch.zeros_like(qkv) for _ in range(world_size)]
dist.all_gather(gather_list, qkv, group=process_group)
gather_weight = torch.cat(gather_list, dim=-1)

if is_transposed:
gather_weight = torch.cat(gather_list, dim=-1)
else:
gather_weight = torch.cat(gather_list, dim=0)
gather_weight = gather_weight.to(origin_device)
qkv = qkv.to(origin_device)

@@ -84,15 +117,23 @@ def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGrou
# [Q1, K1, V1, Q2, K2, V2]
# to
# [Q1, Q2, K1, K2, V1, V2]
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
if is_transposed:
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
else:
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0)

reordered_chunk_list = []
for i in range(n_fused):
reordered_chunk_list.extend(weight_chunks[i::n_fused])
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)

if is_transposed:
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
else:
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=0)
return reordered_gather_weight
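
The reorder after all_gather is the subtle step: chunks arrive rank-major as [Q1, K1, V1, Q2, K2, V2] and must be regrouped by Q/K/V. A toy sketch of just that step:

```python
# Toy sketch of the post-all_gather reorder: chunks arrive rank-major as
# [Q1, K1, V1, Q2, K2, V2] and are regrouped to [Q1, Q2, K1, K2, V1, V2].
import torch

def reorder_after_gather(gathered: torch.Tensor, world_size: int, n_fused: int = 3):
    chunks = torch.chunk(gathered, world_size * n_fused, dim=-1)
    reordered = []
    for i in range(n_fused):
        reordered.extend(chunks[i::n_fused])   # all Q chunks, then K, then V
    return torch.cat(reordered, dim=-1)

# Columns labelled by identity: Q1 K1 V1 Q2 K2 V2 -> Q1 Q2 K1 K2 V1 V2
gathered = torch.tensor([[0., 2., 4., 1., 3., 5.]])
print(reorder_after_gather(gathered, world_size=2))  # [[0., 1., 2., 3., 4., 5.]]
```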


class LinearConv1D_Col(ParallelModule):
class GPT2FusedLinearConv1D_Col(ParallelModule):
r"""Linear layer with column parallelism.

The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
@@ -154,10 +195,10 @@ def __init__(self,
weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)

def shard_fn(tensor):
return split_fused_qkv(tensor, self.n_fused, self.process_group)
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)

def gather_fn(tensor):
return gather_fused_qkv(tensor, 3, self.process_group)
return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True)

with torch.no_grad():
sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
@@ -202,21 +243,27 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]

linear_1d = LinearConv1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
*args,
**kwargs)
linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
*args,
**kwargs)

# TODO: copy the sharded weights
with torch.no_grad():
sharded_weight = split_fused_qkv(module.weight.data, n_fused=n_fused, process_group=process_group)
sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
n_fused=n_fused,
process_group=process_group,
is_transposed=True)
linear_1d.weight.data.copy_(sharded_weight.data)

if bias:
sharded_bias = split_fused_qkv(module.bias.data, n_fused=n_fused, process_group=process_group)
sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
n_fused=n_fused,
process_group=process_group,
is_transposed=True)
linear_1d.bias.data.copy_(sharded_bias.data)

return linear_1d
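
For context, converting a GPT-2 fused attention projection might look like the sketch below; the n_fused and process_group keywords are inferred from the signature above, so treat this as an assumption rather than documented usage:

```python
# Hypothetical usage sketch, inferred from from_native_module above; must run
# under a distributed launcher (e.g. torchrun) so the process group can form.
import torch.distributed as dist
from transformers.pytorch_utils import Conv1D

from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col

dist.init_process_group(backend="nccl")
c_attn = Conv1D(3 * 768, 768)    # GPT-2's fused QKV projection (nf, nx)
sharded = GPT2FusedLinearConv1D_Col.from_native_module(
    c_attn,
    process_group=dist.group.WORLD,   # assumed kwarg, per the signature above
    n_fused=3,                        # Q, K, V fused together
)
```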
@@ -254,7 +301,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
return output


class LinearConv1D_Row(ParallelModule):
class GPT2FusedLinearConv1D_Row(ParallelModule):
r""" Linear layer with row parallelism.
This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface.

@@ -345,13 +392,13 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]

linear_1d = LinearConv1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
*args,
**kwargs)
linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
*args,
**kwargs)

# TODO: copy the sharded weights
with torch.no_grad():