colossalai/shardformer/layer/linear_conv.py (36 changes: 13 additions & 23 deletions)
@@ -104,10 +104,15 @@ def __init__(self,
self.reset_parameters(weight_initializer, bias_initializer)

@staticmethod
- def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int,
+ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
*args, **kwargs) -> ParallelModule:
r"""
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.

+ Args:
+ module (`nn.Linear`): The module to be converted.
+ process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
+ n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
"""
# get the attributes
in_features = module.weight.shape[0]
@@ -136,20 +141,20 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis

# first rearange the order of weight and bias
world_size = dist.get_world_size(group=process_group)
- order = torch.arange(world_size * n_cast)
+ order = torch.arange(world_size * n_fused)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)

- weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=1)
+ weight_chunks = torch.chunk(module.weight.data, world_size * n_fused, dim=1)
rearanged_weight_chunks = [weight_chunks[i] for i in new_order]
rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1)
sharded_weight = shard_colwise(rearanged_weight, process_group)
linear_1d.weight.data.copy_(sharded_weight.T.contiguous())

if bias:
- bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0)
+ bias_chunks = torch.chunk(module.bias.data, world_size * n_fused, dim=0)
rearanged_bias_chunks = [bias_chunks[i] for i in new_order]
rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0)
sharded_bias = shard_colwise(rearanged_bias, process_group)
@@ -262,8 +267,8 @@ def __init__(self,
self.reset_parameters(weight_initializer, bias_initializer)

@staticmethod
- def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int,
- *args, **kwargs) -> ParallelModule:
+ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+ **kwargs) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
@@ -291,26 +296,11 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
with torch.no_grad():
# the weigh to the linear layer is a transpose
# thus shard on col is equal to shard on row

- # first rearange the order of weight and bias
- world_size = dist.get_world_size(group=process_group)
- order = torch.arange(world_size * n_cast)
- new_order = []
- for i in range(world_size):
- new_order.append(order[i::world_size])
- new_order = torch.cat(new_order)
-
- weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=0)
- rearanged_weight_chunks = [weight_chunks[i] for i in new_order]
- rearanged_weight = torch.cat(rearanged_weight_chunks, dim=0)
- sharded_weight = shard_rowwise(rearanged_weight, process_group)
+ sharded_weight = shard_rowwise(module.weight.data, process_group)
linear_1d.weight.data.copy_(sharded_weight.T.contiguous())

if bias:
- bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0)
- rearanged_bias_chunks = [bias_chunks[i] for i in new_order]
- rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0)
- linear_1d.bias.copy_(rearanged_bias.contiguous())
+ linear_1d.bias.copy_(module.bias.data)

return linear_1d

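To make the `n_fused` change above easier to follow, here is a minimal, self-contained sketch (not the library code; the helper name and toy sizes are made up for illustration) of the interleaved reorder that `LinearConv1D_Col.from_native_module` applies to a fused Conv1D weight before `shard_colwise`. GPT-2's `c_attn` concatenates the Q, K and V projections along the output dimension, so naively column-sharding it would hand rank 0 only Q columns; the permutation groups one slice of every projection into each rank's contiguous block.

```python
import torch


def reorder_fused_columns(weight: torch.Tensor, world_size: int, n_fused: int) -> torch.Tensor:
    # Split the output dimension into world_size * n_fused equal chunks, then
    # interleave them so the i-th slice of every fused projection lands in the
    # contiguous block that rank i will receive from column-wise sharding.
    order = torch.arange(world_size * n_fused)
    new_order = torch.cat([order[i::world_size] for i in range(world_size)])
    chunks = torch.chunk(weight, world_size * n_fused, dim=1)
    return torch.cat([chunks[i] for i in new_order], dim=1)


# Toy example: 2 ranks, Q/K/V fused (n_fused=3), six output columns (two per projection).
w = torch.arange(6.0).repeat(2, 1)  # shape (2, 6); columns 0-1 = Q, 2-3 = K, 4-5 = V
print(reorder_fused_columns(w, world_size=2, n_fused=3))
# Columns come out as [0, 2, 4, 1, 3, 5]: the first half (rank 0's shard) now
# holds one column from each of Q, K and V instead of only Q columns.
```

With this reorder in place, `shard_colwise` can split the weight into equal contiguous blocks and each rank still computes a valid slice of the fused projection.
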
colossalai/shardformer/policies/gpt2.py (10 changes: 2 additions & 8 deletions)
@@ -44,29 +44,23 @@ def module_policy(self):
suffix="attn.c_attn",
target_module=col_nn.LinearConv1D_Col,
kwargs={
"n_cast": 3,
"n_fused": 3,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.LinearConv1D_Row,
- kwargs={
- "n_cast": 1,
- },
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.LinearConv1D_Col,
kwargs={
"n_cast": 1,
"n_fused": 1,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.LinearConv1D_Row,
- kwargs={
- "n_cast": 1,
- },
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
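The policy change mirrors the layer change: the kwarg is renamed to `n_fused` and kept only for the column-parallel replacements (`attn.c_attn` with 3 fused projections, `mlp.c_fc` with 1), while the row-parallel `c_proj` replacements drop it because `LinearConv1D_Row.from_native_module` no longer rearranges the weight. A quick standalone check (same interleaving logic as the sketch above; names are illustrative) shows why dropping it is safe: with a single unfused projection the permutation is the identity.

```python
import torch


def interleave(t: torch.Tensor, world_size: int, n_fused: int, dim: int) -> torch.Tensor:
    order = torch.arange(world_size * n_fused)
    new_order = torch.cat([order[i::world_size] for i in range(world_size)])
    chunks = torch.chunk(t, world_size * n_fused, dim=dim)
    return torch.cat([chunks[i] for i in new_order], dim=dim)


w = torch.randn(4, 8)
# With n_fused=1 the permutation [0, 1, ..., world_size - 1] is the identity,
# so the reorder is a no-op for any world size.
assert torch.equal(interleave(w, world_size=2, n_fused=1, dim=1), w)
assert torch.equal(interleave(w, world_size=4, n_fused=1, dim=1), w)
```
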
tests/test_shardformer/test_layer/test_linearconv_1d.py (107 changes: 107 additions & 0 deletions)
@@ -0,0 +1,107 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row
from colossalai.testing import rerun_if_address_is_in_use, spawn


# This code is copied from https://github.com/huggingface/transformers
class Conv1D(nn.Module):
"""
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

Basically works like a linear layer but the weights are transposed.

Args:
nf (`int`): The number of output features.
nx (`int`): The number of input features.
"""

def __init__(self, nf, nx):
super().__init__()
self.nf = nf
self.weight = nn.Parameter(torch.empty(nx, nf))
self.bias = nn.Parameter(torch.zeros(nf))
nn.init.normal_(self.weight, std=0.02)

def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(size_out)
return x


def rearrange(tensor: torch.Tensor, dim: int):
tensor = tensor.clone()
world_size = 2
order = torch.arange(world_size * 3)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)

tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
return rearanged_tensor


def check_linear_conv_1d_col():
linear = Conv1D(192, 48).cuda()
linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3)

assert linear_conv_col.weight.shape == torch.Size([96, 48])
assert linear_conv_col.bias.shape == torch.Size([96])

# check computation correctness
x = torch.rand(4, 48).cuda()
out = linear(x)
gather_out = linear_conv_col(x)
assert_close(rearrange(out, 1), gather_out)

# check backward correctness
out.sum().backward()
gather_out.sum().backward()

rank = dist.get_rank()
target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
assert_close(target_grad.transpose(0, 1).contiguous(), linear_conv_col.weight.grad)


def check_linear_1d_row():
linear = Conv1D(192, 48).cuda()
linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)

assert linear_row.weight.shape == torch.Size([192, 24])
assert linear_row.bias.shape == torch.Size([192])

# check computation correctness
x = torch.rand(4, 48).cuda()
out = linear(x)
gather_out = linear_row(x)
assert_close(out, gather_out)

# check backward correctness
out.sum().backward()
gather_out.sum().backward()

rank = dist.get_rank()
target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
assert_close(target_grad, linear_row.weight.grad)


def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_conv_1d_col()


@rerun_if_address_is_in_use()
def test_linearconv():
spawn(run_dist, nprocs=2)


if __name__ == '__main__':
test_linearconv()
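
For reference, the shape assertions in the new test follow from simple bookkeeping, sketched below without any distributed setup (sizes taken from the test; variable names are illustrative). `Conv1D(nf=192, nx=48)` stores its weight as `(nx, nf) = (48, 192)`; the column-parallel layer shards the 192 output features across 2 ranks and keeps the transposed shard, while the row-parallel layer shards the 48 input features and keeps the full bias.

```python
import torch

nx, nf, world_size = 48, 192, 2
conv1d_weight = torch.empty(nx, nf)  # Conv1D stores weight as (in_features, out_features)

# Column-parallel case (c_attn-style): shard the 192 output features.
col_shard = torch.chunk(conv1d_weight, world_size, dim=1)[0]    # (48, 96)
assert col_shard.T.contiguous().shape == torch.Size([96, 48])   # stored transposed, as asserted
# The bias is sharded the same way: 192 / 2 = 96 elements per rank.

# Row-parallel case (c_proj-style): shard the 48 input features instead.
row_shard = torch.chunk(conv1d_weight, world_size, dim=0)[0]    # (24, 192)
assert row_shard.T.contiguous().shape == torch.Size([192, 24])  # the full (192,) bias stays replicated
```

The test file can be run directly (`python tests/test_shardformer/test_layer/test_linearconv_1d.py`) or through pytest; `spawn(run_dist, nprocs=2)` launches the two-rank NCCL group that the assertions assume, so two GPUs are expected.
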
tests/test_shardformer/test_model/test_shard_gpt2.py (3 changes: 0 additions & 3 deletions)
@@ -42,9 +42,6 @@ def check_gpt2(rank, world_size, port):

sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- print(name)
- # if name == 'transformers_gpt':
- # continue
org_model, sharded_model = build_model(world_size, model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
