44 changes: 38 additions & 6 deletions colossalai/shardformer/_utils.py
@@ -1,25 +1,57 @@
import re


def get_obj_list_element(obj, a):
def get_obj_list_element(obj, attr: str):
r"""
Get an element of a list attribute of the object

If the attr is a normal attribute, return that attribute of the object.
If the attr contains a list index, such as `layers[0]`, return the element at that index of the list.

Args:
obj (object): The object to get the attribute from
attr (str): The suffix of the attribute to get

"""
re_pattern = r'\[\d+\]'
prog = re.compile(re_pattern)
result = prog.search(a)
result = prog.search(attr)
if result:
matched_brackets = result.group()
matched_index = matched_brackets.replace('[', '')
matched_index = matched_index.replace(']', '')
a_ = a.replace(matched_brackets, '')
container_obj = getattr(obj, a_)
attr_ = attr.replace(matched_brackets, '')
container_obj = getattr(obj, attr_)
obj = container_obj[int(matched_index)]
else:
obj = getattr(obj, a)
obj = getattr(obj, attr)
return obj


def set_obj_list_element(obj, attr: str, value):
r"""
Set an element of a list attribute of an object to the given value

It is used like ``set_obj_list_element(obj, 'layers[0]', new_layer)``, which sets ``obj.layers[0]`` to ``new_layer``.

Args:
obj (object): The object to set the attribute on
attr (str): the string including a list index like `layers[0]`
"""
re_pattern = r'\[\d+\]'
prog = re.compile(re_pattern)
result = prog.search(attr)
if result:
matched_brackets = result.group()
matched_index = matched_brackets.replace('[', '')
matched_index = matched_index.replace(']', '')
attr_ = attr.replace(matched_brackets, '')
container_obj = getattr(obj, attr_)
container_obj[int(matched_index)] = value
else:
setattr(obj, attr, value)


def hasattr_(obj, attr: str):
r"""
Check whether the object has the given multi-level (nested) attribute
@@ -56,7 +88,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
if ignore:
return
raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
setattr(obj, attrs[-1], value)
set_obj_list_element(obj, attrs[-1], value)


def getattr_(obj, attr: str, ignore: bool = False):
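For readers outside this diff, a minimal usage sketch of the two index-aware helpers. The toy `model` below is hypothetical, not part of this PR:

```python
import torch.nn as nn

from colossalai.shardformer._utils import get_obj_list_element, set_obj_list_element

# Toy container module (illustrative assumption only).
model = nn.Module()
model.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])

# 'layers[0]' is parsed into attribute 'layers' plus index 0.
first = get_obj_list_element(model, 'layers[0]')
assert first is model.layers[0]

# Replaces model.layers[1] in place via the same index parsing.
set_obj_list_element(model, 'layers[1]', nn.Linear(4, 8))
assert model.layers[1].out_features == 8
```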
4 changes: 2 additions & 2 deletions colossalai/shardformer/layer/__init__.py
@@ -3,10 +3,10 @@
from .linear import Linear1D_Col, Linear1D_Row
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm
from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

__all__ = [
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
'FusedLayerNorm', 'FusedRMSNorm'
'FusedLayerNorm', 'FusedRMSNorm', 'FusedLinear1D_Col'
]
175 changes: 174 additions & 1 deletion colossalai/shardformer/layer/qkv_fused_linear.py
@@ -23,6 +23,7 @@

from ._operation import (
gather_forward_split_backward,
linear_with_async_comm,
matmul_with_async_comm,
reduce_backward,
reduce_forward,
@@ -31,7 +32,7 @@
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset

__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row']
__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row', 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row']

# ====================================
# For GPT Only
@@ -471,3 +472,175 @@ def forward(self, input_: Tensor) -> Tensor:
return output
else:
return output, self.bias


# ====================================
# For Fused torch.nn.Linear
# ====================================


class FusedLinear1D_Col(ParallelModule):
r"""Fused Linear layer with column parallelism.

The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to replace a fused-QKV `torch.nn.Linear` layer in standard HuggingFace models, such as SAM.

Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number of weights fused together, defaults to 3 (Q, K, V).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If ``True``, call all-gather on the output so that :math:`Y` is
available to all GPUs; otherwise, every GPU will only have its own partition,
which is :math:`Y_i = XA_i`, defaults to False
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.

For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""

def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()

# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
self.process_group = process_group
self.async_communication = async_communication

if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')

# Parameters.
# Initialize weight.
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)

def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)

def gather_fn(tensor):
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)

with torch.no_grad():
sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
self.weight = customized_distributed_tensor_to_param(sharded_weight)

if bias:
bias = torch.empty(self.out_features, **factory_kwargs)

with torch.no_grad():
sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
self.bias = customized_distributed_tensor_to_param(sharded_bias)
else:
self.bias = None

# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

# init weights
self.reset_parameters(weight_initializer, bias_initializer)

@staticmethod
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
*args, **kwargs) -> ParallelModule:
r"""
Convert a fused `torch.nn.linear` layer to a parallelized linear layer.

Args:
module (`nn.Linear`): The module to be converted.
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of weights fused together. Commonly, Q, K and V are fused into one weight.
"""
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device

# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]

linear_1d = FusedLinear1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
*args,
**kwargs)

# copy the sharded weights
with torch.no_grad():
sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
n_fused=n_fused,
process_group=process_group,
is_transposed=False)
linear_1d.weight.data.copy_(sharded_weight.data)

if bias:
sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
n_fused=n_fused,
process_group=process_group,
is_transposed=False)
linear_1d.bias.data.copy_(sharded_bias.data)

return linear_1d

def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with self.randomizer.fork_rng(enable_cpu=True):
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)

def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in FusedLinear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
# input_parallel = reduce_backward(input_, self.process_group)
input_parallel = input_

# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None

output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)

if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
else:
output = output_parallel

if self.skip_bias_add:
return output, self.bias
else:
return output
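A hedged usage sketch of the new `FusedLinear1D_Col`: the process-group setup and dimensions below are illustrative assumptions, not part of this diff, and the sketch assumes `torch.distributed` has already been initialized (e.g. via `colossalai.launch`):

```python
import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.shardformer.layer import FusedLinear1D_Col

# Assumed: the default process group is already initialized.
process_group = dist.group.WORLD

# A fused QKV projection as found in HuggingFace SAM: out_features = 3 * hidden_size.
fused_qkv = nn.Linear(64, 3 * 64)

# Shard the fused weight column-wise; n_fused=3 keeps Q, K, V each contiguous per rank.
parallel_qkv = FusedLinear1D_Col.from_native_module(fused_qkv,
                                                    process_group=process_group,
                                                    n_fused=3)

x = torch.randn(2, 16, 64)
out = parallel_qkv(x)    # each rank computes Y_i = X A_i (gather_output=False by default)
```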
41 changes: 41 additions & 0 deletions colossalai/shardformer/modeling/sam.py
@@ -0,0 +1,41 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


def forward_fn():

def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
# qkv with shape (3, batch_size, nHead, height * width, channel)
qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads,
-1).permute(2, 0, 3, 1, 4))
# q, k, v with shape (batch_size * nHead, height * width, channel)
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)

attn_weights = (query * self.scale) @ key.transpose(-2, -1)

if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(attn_weights, query, self.rel_pos_h, self.rel_pos_w,
(height, width), (height, width))

attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)

# replace dropout process with added DropoutForParallelInput layer
# origin code:
# attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_probs = self.dropout_layer(attn_weights)

attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)

attn_output = self.proj(attn_output)

if output_attentions:
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)

return outputs

return forward
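For context, a hedged sketch of how a closure like this is typically wired in by a shardformer policy. The attribute names and binding approach below are illustrative assumptions; in particular, `dropout_layer` is assumed to have been injected onto the attention module by the SAM policy before the forward is swapped:

```python
import types

from colossalai.shardformer.modeling.sam import forward_fn

def patch_attention(attn_module):
    # attn_module is assumed to be a transformers SamVisionAttention instance
    # that already carries a `dropout_layer` attribute added by the SAM policy.
    # Bind the replacement forward so `self` resolves to this instance.
    attn_module.forward = types.MethodType(forward_fn(), attn_module)
```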
6 changes: 5 additions & 1 deletion colossalai/shardformer/policies/autopolicy.py
@@ -88,7 +88,7 @@ class PolicyLocation:
PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"),
"transformers.models.opt.modeling_opt.OPTForQuestionAnswering":
PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"),

# Bloom
"transformers.models.bloom.modeling_bloom.BloomModel":
PolicyLocation(file_name="bloom", class_name="BloomModelPolicy"),
@@ -100,6 +100,10 @@ class PolicyLocation:
PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"),
"transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering":
PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"),

# Sam
"transformers.models.sam.modeling_sam.SamModel":
PolicyLocation(file_name="sam", class_name="SamModelPolicy"),
}


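A simplified sketch of how the mapping above is consumed: the lookup key is the fully-qualified class name of the model, and the resolved `PolicyLocation` names the policy module and class. The function below is an illustrative reconstruction, not the exact code in `autopolicy.py`:

```python
import importlib

def resolve_policy(model, policy_location_table):
    # Build the fully-qualified name used as the table key,
    # e.g. "transformers.models.sam.modeling_sam.SamModel".
    cls = model.__class__
    qualified_name = f"{cls.__module__}.{cls.__name__}"
    location = policy_location_table[qualified_name]
    # Policies live under colossalai.shardformer.policies.<file_name>.
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()
```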