4 changes: 2 additions & 2 deletions colossalai/shardformer/layer/dropout.py
@@ -1,5 +1,4 @@
import os
-import time
from contextlib import contextmanager

import torch
@@ -14,7 +13,8 @@ class SeedManager:

def __init__(self):
original_state = torch.cuda.get_rng_state()
seed = int(f"{int(time.time())}{os.environ['RANK']}")
# TODO: unify this seed manager with the colossalai.context.random
seed = os.getpid()
torch.cuda.manual_seed(int(seed))
self.dropout_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(original_state)
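For context on the change above: SeedManager keeps a private CUDA RNG state per process, so after switching the seed from a timestamp/RANK string to `os.getpid()` every spawned rank still draws its own dropout mask without disturbing the global RNG stream. Below is a minimal sketch of that state-swapping pattern; `SeedManager.__init__` mirrors the diff, but the `dropout_mode` context manager is a hypothetical helper added here for illustration, not code from this PR.

```python
import os
from contextlib import contextmanager

import torch


class SeedManager:
    """Sketch: hold a dedicated CUDA RNG state, seeded once per process."""

    def __init__(self):
        original_state = torch.cuda.get_rng_state()
        # Each spawned rank has a distinct PID, so each rank gets its own
        # dropout stream (mirrors the `seed = os.getpid()` change above).
        torch.cuda.manual_seed(os.getpid())
        self.dropout_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(original_state)

    @contextmanager
    def dropout_mode(self):
        # Swap the private state in, run the wrapped op, then save and restore.
        original_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.dropout_state)
        try:
            yield
        finally:
            self.dropout_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(original_state)
```

Running `F.dropout` inside `dropout_mode()` is how a layer such as `Dropout1D` can produce rank-specific masks, which is what the new `test_dropout.py` below asserts.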
16 changes: 8 additions & 8 deletions colossalai/shardformer/shard/slicer.py
@@ -3,7 +3,7 @@
from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
from .shard_config import ShardConfig

-dim_mapping = {Col_Layer: 1, Row_Layer: 0}
+dim_mapping = {Col_Layer: 0, Row_Layer: 1}


class Slicer():
@@ -40,7 +40,7 @@ def slice_weight_bias(
# print(weight.shape, dim)
if policy_layer_cls == Col_Layer:
weight = self.slice_tensor(weight, dim, False, n_cast)
-            bias = self.slice_tensor(bias, 0, True)
+            bias = self.slice_tensor(bias, 0, True, n_cast)
elif policy_layer_cls == Row_Layer:
weight = self.slice_tensor(weight, dim, False, n_cast)
else:
@@ -129,13 +129,13 @@ def slice_col(

"""
if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
chunk_list = [
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
]
-            return torch.cat(chunk_list, dim=0).contiguous()
+            return torch.cat(chunk_list, dim=1).contiguous()

def slice_row(
self,
@@ -152,10 +152,10 @@ def slice_row(
:class:`torch.Tensor`: The sliced tensor
"""
if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
+            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
chunk_list = [
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
]
-            return torch.cat(chunk_list, dim=1).contiguous()
+            return torch.cat(chunk_list, dim=0).contiguous()
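The slice_col/slice_row changes above swap the chunk/cat dimensions, and both keep the same n_cast pattern: split the tensor into `world_size * n_cast` chunks along the sharding dimension, keep every `world_size`-th chunk starting at this rank, and concatenate the survivors. A self-contained sketch of that interleaved slicing is below; `interleaved_slice` is a hypothetical helper (rank/world_size are plain arguments standing in for the `ShardConfig` fields used in the diff), and the fused-projection reading of `n_cast` is an assumption rather than something stated in this PR.

```python
from typing import Optional

import torch


def interleaved_slice(tensor: torch.Tensor, rank: int, world_size: int,
                      dim: int, n_cast: Optional[int] = None) -> torch.Tensor:
    """Sketch of the chunk-and-interleave slicing used by Slicer."""
    if n_cast is None:
        # Plain 1D tensor parallelism: one contiguous chunk per rank.
        return tensor.chunk(world_size, dim=dim)[rank].contiguous()
    # A fused weight (e.g. a packed multi-head projection) is split into
    # world_size * n_cast chunks; each rank keeps every world_size-th chunk
    # so every sub-projection remains represented on every rank.
    chunks = tensor.chunk(world_size * n_cast, dim=dim)
    kept = [chunks[i] for i in range(rank, len(chunks), world_size)]
    return torch.cat(kept, dim=dim).contiguous()


if __name__ == "__main__":
    w = torch.arange(24).reshape(4, 6).float()
    # Rank 0 of 2, sliced along dim 0 with n_cast=2 -> shape (2, 6).
    print(interleaved_slice(w, rank=0, world_size=2, dim=0, n_cast=2).shape)
```

After the `dim_mapping` swap, a `Col_Layer` weight is sharded along dim 0 and a `Row_Layer` weight along dim 1, which is exactly what the new `test_slicer.py` below expects.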
51 changes: 51 additions & 0 deletions tests/test_shardformer/test_module/test_dropout.py
@@ -0,0 +1,51 @@
import pytest
import torch
import torch.nn.functional as F

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.dropout import Dropout1D
from colossalai.testing import rerun_if_address_is_in_use, spawn

CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)


def check_dropout(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')

# prepare data
input = torch.randn(5, 4).to('cuda')
dropout = Dropout1D(p=0.4).to('cuda')
output_list = []
# compare the dropout pattern in each device
for i in range(2):
output = dropout(input)
output_list.append(output)
dist_output_list = [torch.zeros(*output.shape).to('cuda') for _ in range(world_size)]
torch.distributed.all_gather(dist_output_list, output)
for j in range(world_size):
for k in range(world_size):
if j != k:
mask = torch.eq(dist_output_list[j], 0.0) == torch.eq(dist_output_list[k], 0.0)
assert torch.all(
mask
) == False, f"The dropout pattern in each device is not unique\n{dist_output_list[j]}\n{dist_output_list[k]}"
    # compare the dropout pattern in the local device
for i in range(len(output_list)):
for j in range(len(output_list)):
if i != j:
mask = torch.eq(output_list[i], 0.0) == torch.eq(output_list[j], 0.0)
assert torch.all(
mask
) == False, f"The dropout pattern in one device is not unique\n{output_list[i]}\n{output_list[j]}"


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dropout():
spawn(check_dropout, 2)


if __name__ == '__main__':
test_dropout()
78 changes: 78 additions & 0 deletions tests/test_shardformer/test_module/test_slicer.py
@@ -0,0 +1,78 @@
import pytest
import torch
import torch.nn.functional as F

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.policies.basepolicy import Col_Layer, Layer, Row_Layer
from colossalai.shardformer.shard.shard_config import ShardConfig
from colossalai.shardformer.shard.slicer import Slicer
from colossalai.testing import rerun_if_address_is_in_use, spawn

CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)


def check_slicer(rank, world_size, port, in_feature, out_feature):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
# initialize slicer
shardconfig = ShardConfig(rank=rank, world_size=world_size)
slicer = Slicer(shardconfig)
# initialize test data
weight = torch.randn(in_feature, out_feature)
bias = torch.randn(out_feature)
policy_layer_cls_list = [Layer, Col_Layer, Row_Layer]
n_cast_list = [None, 2, 3, 4]
# weight and bias
for n_cast in n_cast_list:
sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Layer, n_cast=n_cast)
expected_sliced_weight = weight
expected_sliced_bias = bias
assert torch.equal(
sliced_weight, expected_sliced_weight
), f"In Layer case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
assert torch.equal(
sliced_bias, expected_sliced_bias
), f"In Layer case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"

sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Col_Layer, n_cast=n_cast)
if (n_cast is None):
expected_sliced_weight = weight.chunk(world_size, dim=0)[rank]
expected_sliced_bias = bias.chunk(world_size)[rank]
else:
chunks = weight.chunk(world_size * n_cast, dim=0)
expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=0)
chunks = bias.chunk(world_size * n_cast, dim=0)
expected_sliced_bias = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)])
assert torch.equal(
sliced_weight, expected_sliced_weight
), f"In Col_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
assert torch.equal(
sliced_bias, expected_sliced_bias
), f"In Col_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}"

sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Row_Layer, n_cast=n_cast)
if (n_cast is None):
expected_sliced_weight = weight.chunk(world_size, dim=1)[rank]
expected_sliced_bias = bias
else:
chunks = weight.chunk(world_size * n_cast, dim=1)
expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=1)
expected_sliced_bias = bias
assert torch.equal(
sliced_weight, expected_sliced_weight
), f"In Row_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
assert torch.equal(
sliced_bias, expected_sliced_bias
), f"In Row_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_slicer():
args = dict(in_feature=24, out_feature=48)
spawn(check_slicer, nprocs=2, in_feature=args['in_feature'], out_feature=args['out_feature'])


if __name__ == '__main__':
test_slicer()