@@ -2,6 +2,9 @@


 def get_obj_list_element(obj, a):
+    r"""
+    Get an element of a list attribute in the object by its index.
+    """
     re_pattern = r'\[\d+\]'
     prog = re.compile(re_pattern)
     result = prog.search(a)
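For context, the regex in this helper matches a bracketed index such as `[0]` inside an attribute path string like `layers[0]`. Below is a standalone, hypothetical sketch of the idea; the helper and class names are invented, and the truncated body of `get_obj_list_element` may differ:

import re

def get_list_element_by_index(obj, attr_str):
    # Hypothetical illustration: resolve a string like 'layers[2]' against obj.
    match = re.search(r'\[(\d+)\]', attr_str)
    if match is None:
        return None
    name = attr_str[:match.start()]    # attribute holding the list, e.g. 'layers'
    index = int(match.group(1))        # parsed list index, e.g. 2
    return getattr(obj, name)[index]

class _Holder:
    def __init__(self):
        self.layers = ['a', 'b', 'c']

assert get_list_element_by_index(_Holder(), 'layers[2]') == 'c'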
19 changes: 6 additions & 13 deletions colossalai/shardformer/layer/__init__.py
@@ -1,17 +1,10 @@
 from .dropout import Dropout1D
-from .embedding1d import Embedding1D
-from .layernorm1d import LayerNorm1D
-from .linear1d import Linear1D_Col, Linear1D_Row
-from .linearconv1d import LinearConv1D_Col, LinearConv1D_Row
-from .vocabparallelembedding1d import VocabParallelEmbedding1D
+from .embedding import Embedding1D, VocabParallelEmbedding1D
+from .linear import Linear1D_Col, Linear1D_Row
+from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
+from .loss import cross_entropy_1d

 __all__ = [
-    "Embedding1D",
-    "VocabParallelEmbedding1D",
-    "Linear1D_Col",
-    "Linear1D_Row",
-    "LinearConv1D_Col",
-    "LinearConv1D_Row",
-    "LayerNorm1D",
-    "Dropout1D",
+    "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
+    "Dropout1D", "cross_entropy_1d"
 ]
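With this consolidation, downstream code imports every parallel layer from the single `colossalai.shardformer.layer` namespace. A minimal sketch of the new import surface, using only the names exported in `__all__` above:

from colossalai.shardformer.layer import (
    Dropout1D,
    Embedding1D,
    Linear1D_Col,
    Linear1D_Row,
    LinearConv1D_Col,
    LinearConv1D_Row,
    VocabParallelEmbedding1D,
    cross_entropy_1d,
)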
2 changes: 0 additions & 2 deletions colossalai/shardformer/layer/_operation.py
@@ -1,8 +1,6 @@
 import torch
 import torch.distributed as dist

-from colossalai.core import global_context as gpc
-
 try:
     import fused_mix_prec_layer_norm_cuda
 except:
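The try/except above is the usual optional-dependency guard around a fused CUDA extension. A generic sketch of the pattern with a narrower exception and an explicit availability helper (both are illustrative additions, not part of this diff):

# Fall back gracefully when the fused kernel extension is absent.
try:
    import fused_mix_prec_layer_norm_cuda
except ImportError:
    fused_mix_prec_layer_norm_cuda = None

def fused_layer_norm_available() -> bool:
    return fused_mix_prec_layer_norm_cuda is not None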
4 changes: 3 additions & 1 deletion colossalai/shardformer/layer/dropout.py
@@ -4,9 +4,11 @@
 import torch.nn as nn
 from torch.distributed import ProcessGroup

-from .parallelmodule import ParallelModule
+from .parallel_module import ParallelModule
 from .utils import create_randomizer_with_offset

+__all__ = ['Dropout1D']
+

 class Dropout1D(ParallelModule, nn.Dropout):
     """
colossalai/shardformer/layer/embedding.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from collections import OrderedDict
 from typing import Callable, List, Union

 import torch
@@ -12,26 +11,148 @@
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

-from colossalai.context import ParallelMode, seed
 from colossalai.nn import init as init
-from colossalai.nn.layer.base_layer import ParallelLayer
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_rowwise
-from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict
+from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
 from colossalai.utils.cuda import get_current_device

-from ._operation import reduce_input
-from .parallelmodule import ParallelModule
+from ._operation import gather_forward_split_backward, reduce_input
+from .parallel_module import ParallelModule
 from .utils import create_randomizer_with_offset

-Fast_LN = None
-try:
-    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-    Fast_LN = FastLayerNorm
-except ImportError:
-    pass
+__all__ = ['Embedding1D', 'VocabParallelEmbedding1D']


-class VocabParallelEmbedding1D(ParallelLayer):
+class Embedding1D(ParallelModule):
+    r"""Embedding for 1D parallelism.
+
+    Args:
+        num_embeddings (int): number of embeddings.
+        embedding_dim (int): dimension of embedding.
+        padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
+            therefore, the embedding vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed “pad”, defaults to None.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+        weight_initializer (:class:`typing.Callable`, optional):
+            The initializer of weight, defaults to normal initializer.
+
+    The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain:
+    ::
+
+        max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
+            renormalized to have norm max_norm. Note: this will modify weight in-place.
+        norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
+        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
+            of frequency of the words in the mini-batch. Default False.
+        sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
+
+    More details about ``args`` and ``kwargs`` can be found in
+    `Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.
+
+    For more details about ``initializer``, please refer to
+    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
+    """

+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: int = None,
+                 dtype: torch.dtype = None,
+                 device: torch.device = None,
+                 process_group: ProcessGroup = None,
+                 gather_output: bool = True,
+                 weight_initializer: Callable = init.normal_(),
+                 *args,
+                 **kwargs):
+        super().__init__()
+
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.process_group = process_group
+        self.num_partitions = dist.get_world_size(process_group)
+        self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
+
+        self.padding_idx = padding_idx
+        self.embed_args = args
+        self.embed_kwargs = kwargs
+        self.gather_output = gather_output
+
+        if device is None:
+            device = get_current_device()
+
+        self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype))
+
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        with self.randomizer.fork_rng(enable_cpu=True):
+            self.reset_parameters(weight_initializer)
+    @staticmethod
+    def from_native_module(module: nn.Embedding,
+                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
+                           *args,
+                           **kwargs) -> "Embedding1D":
+        r"""
+        Build a 1D parallelized Embedding from a native nn.Embedding module.
+        """
+        # get the attributes
+        num_embedding = module.num_embeddings
+        embedding_dim = module.embedding_dim
+        padding_idx = module.padding_idx
+        max_norm = module.max_norm
+        norm_type = module.norm_type
+        scale_grad_by_freq = module.scale_grad_by_freq
+        sparse = module.sparse
+        dtype = module.weight.dtype
+        device = module.weight.device
+
+        # sparse is not supported yet
+        if sparse:
+            raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.")
+
+        embedding = Embedding1D(num_embeddings=num_embedding,
+                                embedding_dim=embedding_dim,
+                                padding_idx=padding_idx,
+                                process_group=process_group,
+                                dtype=dtype,
+                                device=device,
+                                max_norm=max_norm,
+                                norm_type=norm_type,
+                                scale_grad_by_freq=scale_grad_by_freq,
+                                sparse=sparse,
+                                *args,
+                                **kwargs)
+
+        # copy the weight
+        with torch.no_grad():
+            sharded_weight = shard_colwise(module.weight.data, process_group)
+            embedding.weight.copy_(sharded_weight)
+
+        return embedding

+    def reset_parameters(self, weight_initializer) -> None:
+        fan_in, fan_out = self.num_embeddings, self.embedding_dim
+        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+        self._fill_padding_idx_with_zero()
+
+    def _fill_padding_idx_with_zero(self) -> None:
+        if self.padding_idx is not None:
+            with torch.no_grad():
+                self.weight[self.padding_idx].fill_(0)
+
+    def forward(self, input_: Tensor) -> Tensor:
+        output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+
+        if self.gather_output:
+            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            return output
+        else:
+            return output_parallel


+class VocabParallelEmbedding1D(ParallelModule):
     r"""Embedding parallelized in the vocabulary dimension.

     Args:
@@ -93,9 +214,7 @@ def __init__(self,
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
-
-        with self.randomizer.fork_rng(enable_cpu=True):
-            self.reset_parameters(weight_initializer)
+        self.reset_parameters(weight_initializer)

     @staticmethod
     def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
@@ -132,7 +251,7 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup,
         return vocab_embedding_1d

     def reset_parameters(self, weight_initializer) -> None:
-        with seed(ParallelMode.TENSOR):
+        with self.randomizer.fork_rng(enable_cpu=True):
             fan_in, fan_out = self.num_embeddings, self.embed_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()
@@ -143,16 +262,6 @@ def _fill_padding_idx_with_zero(self) -> None:
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

-    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        local_state = OrderedDict({weight_key: self.weight})
-        local_state = gather_tensor_parallel_state_dict(local_state,
-                                                        ParallelMode.PARALLEL_1D,
-                                                        dims={weight_key: 0},
-                                                        partition_states={weight_key: True},
-                                                        keep_vars=keep_vars)
-        destination.update(local_state)
-
     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
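To make the conversion path above concrete, here is a hypothetical single-rank usage sketch of `Embedding1D.from_native_module`. It assumes `torch.distributed` is initialized (the constructor calls `dist.get_world_size`); with more ranks each rank would hold `embedding_dim // world_size` columns of the weight:

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.shardformer.layer import Embedding1D

# Single-rank group so dist.get_world_size() works in this sketch.
if not dist.is_initialized():
    dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:29500',
                            rank=0, world_size=1)

native = nn.Embedding(num_embeddings=1000, embedding_dim=128, padding_idx=0)
parallel = Embedding1D.from_native_module(native, process_group=None)

tokens = torch.randint(0, 1000, (4, 16))
out = parallel(tokens)    # gather_output=True (default) restores the full embedding_dim
assert out.shape == (4, 16, 128)

The truncated `VocabParallelEmbedding1D.forward` follows the standard vocabulary-parallel scheme: token ids outside `[vocab_start_index, vocab_end_index)` are masked, looked up against the local vocabulary shard, zeroed via the mask, and the partial embeddings are presumably summed across ranks with `reduce_input`, matching the import added in this diff.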