Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions colossalai/tensor/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get(self, rank_list: List[int], backend: str = 'nccl'):
if pg_key not in self.dict:

self.logger = get_dist_logger('ProcessGroup')
self.logger.info(f'NCCL initialize TP group on {rank_list}', ranks=[0])
self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0])

self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
return self.dict[pg_key]
Expand Down Expand Up @@ -63,7 +63,6 @@ def __init__(self,
self._rank_list = ranks
self._rank_list.sort() # ensure that the list is in order

self._rank_idx = self._rank_list.index(self._rank)
self._world_size = len(self._rank_list)

if dp_degree is None and tp_degree is None:
Expand All @@ -84,19 +83,22 @@ def __init__(self,
f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \
f"and TP degree {self._tp_degree}"

self._tp_rank_list = []
self._dp_rank_list = []
self._tp_rank_list = None
self._dp_rank_list = None

for idx, rank_id in enumerate(self._rank_list):
# idx and self._rank_idx in the same tp group
if idx % self._tp_degree == self._rank_idx % self._tp_degree:
self._dp_rank_list.append(rank_id)
if idx // self._tp_degree == self._rank_idx // self._tp_degree:
self._tp_rank_list.append(rank_id)
for i in range(self._dp_degree):
i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
PYTORCHPGDICT_.get(i_tp_list, 'nccl')
if self._rank in i_tp_list:
self._tp_rank_list = i_tp_list

for j in range(self._tp_degree):
j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
PYTORCHPGDICT_.get(j_dp_list, 'nccl')
if self._rank in j_dp_list:
self._dp_rank_list = j_dp_list

self._has_cpu_groups = False
PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
self.is_init = True

def set_cpu_groups(self):
Expand All @@ -106,6 +108,7 @@ def set_cpu_groups(self):
f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
self._has_cpu_groups = True

@property
def has_cpu_groups(self):
Expand Down Expand Up @@ -162,7 +165,9 @@ def tp_process_group(self):
return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')

def cpu_dp_process_group(self):
assert self._has_cpu_groups
return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')

def cpu_tp_process_group(self):
assert self._has_cpu_groups
return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
40 changes: 21 additions & 19 deletions tests/test_tensor/test_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,13 @@
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ColoTensor, ColoTensorSpec
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor, ColoTensorSpec
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from tests.components_to_test.registry import non_distributed_component_funcs


def init_1d_row_spec(model, pg: ProcessGroup):
tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))

for n, p in model.named_parameters():
p.set_process_group(pg)
if 'weight' in n and 'ln' not in n:
Expand Down Expand Up @@ -50,33 +47,39 @@ def check_grad_equal(model, torch_model, pg: ProcessGroup):


def run_gpt(init_spec_func, use_ddp):
set_seed(13234)
world_size = torch.distributed.get_world_size()

# build a PG with TP and DP hybrid
pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))

# set seed make processes of the same tp group use the same seed
# set_seed(pg.tp_local_rank())

get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

# make sure torch_model and model has the same parameter values
with ColoInitContext(device=get_current_device()):
model = model_builder()
model = model.cuda()
torch_model = model_builder().cuda()
if use_ddp:
# torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg)
# torch.distributed.barrier()
torch_model = DDP(torch_model,
device_ids=[gpc.get_global_rank()],
process_group=gpc.get_group(ParallelMode.DATA))

if use_ddp:
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
model = ColoDDP(model, process_group=pg)

for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p)

init_spec_func(model, pg)

check_param_equal(model, torch_model, pg)
model.train()
torch_model.train()
torch.distributed.barrier()

# close the dropout in eval mode
model.eval()
torch_model.eval()
set_seed(pg.dp_local_rank())
torch.distributed.barrier()
Comment thread
1SAA marked this conversation as resolved.
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
logits = model(colo_input, attn_mask)
Expand All @@ -92,26 +95,25 @@ def run_gpt(init_spec_func, use_ddp):
check_grad_equal(model, torch_model, pg)
if i > 0:
break
set_seed(313)


def run_dist(rank, world_size, port, use_ddp):
if use_ddp and world_size == 1:
return
tp_world_size = world_size // 2 if use_ddp else world_size
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_gpt(init_1d_row_spec, use_ddp)
run_gpt(init_1d_col_spec, use_ddp)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_gpt(world_size, use_ddp):
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
test_gpt(4, False)
test_gpt(4, use_ddp=True)
18 changes: 11 additions & 7 deletions tests/test_tensor/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def run_1d_hybrid_tp(model_name):
split_param_row_tp1d(p, pg)

model = model.cuda()
model.train()
model.eval()
if rank == 0:
model_torch.train()
model_torch.eval()

colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))

Expand All @@ -89,6 +89,7 @@ def run_1d_hybrid_tp(model_name):
colo_optimizer.zero_grad()
if rank == 0:
optimizer_torch.zero_grad()
torch.distributed.barrier()

data = data.to(get_current_device())
label = label.to(get_current_device())
Expand All @@ -113,6 +114,7 @@ def run_1d_hybrid_tp(model_name):
output_torch = model_torch(data, label)
loss_torch = output_torch
assert torch.allclose(loss, loss_torch, rtol=1e-2)
torch.distributed.barrier()

loss.backward()
colo_optimizer.step()
Expand All @@ -125,7 +127,7 @@ def run_1d_hybrid_tp(model_name):
# check param
for p, torch_p in zip(model.parameters(), model_torch.parameters()):
assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())

torch.distributed.barrier()
if i > 5:
break

Expand Down Expand Up @@ -248,14 +250,15 @@ def run_1d_row_tp(model_name: str):
else:
output_torch = model_torch(data, label)
loss_torch = output_torch

if rank == 0:
assert torch.allclose(loss, loss_torch, rtol=1e-2)
torch.distributed.barrier()

loss.backward()

if rank == 0:
loss_torch.backward()
torch.distributed.barrier()

if i > 5:
break

Expand Down Expand Up @@ -296,8 +299,9 @@ def _run_pretrain_load():

def run_model_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
for name in ['bert', 'simple_net']:
run_1d_row_tp(name)
# Comment below test for speed consideration
# for name in ['bert', 'simple_net']:
# run_1d_row_tp(name)
Comment thread
1SAA marked this conversation as resolved.
for name in ['bert', 'simple_net']:
run_1d_hybrid_tp(name)

Expand Down
66 changes: 36 additions & 30 deletions tests/test_tensor/test_zero_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,25 @@
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor


def check_param_equal(model, torch_model, pg: ProcessGroup):
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
if p.storage().size() > 0:
assert p.dtype == torch.half
assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
pg.tp_world_size()), f'{torch_p} vs {p}'
assert p.dtype == torch.float16
assert tensor_shard_equal(tp.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
pg.tp_world_size()), f'{tp} vs {p}\n{n}:\n\t{tp.shape} vs {p.shape}'


def check_grad_equal(model, torch_model, pg: ProcessGroup):
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
if p.grad is not None:
assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
pg.tp_local_rank(), pg.tp_world_size())
torch.distributed.barrier()
print(torch.distributed.get_rank(), p.grad)
assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
pg.tp_local_rank(), pg.tp_world_size()), \
f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'


def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
Expand All @@ -46,34 +49,35 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):

def init_1d_row_spec(model, pg: ProcessGroup):
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
p.set_tensor_spec(*spec)
for n, p in model.named_parameters():
p.set_process_group(pg)
if 'weight' in n and 'ln' not in n:
p.set_tensor_spec(*spec)


def init_1d_col_spec(model, pg: ProcessGroup):
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n):
p.set_tensor_spec(*spec)
for n, p in model.named_parameters():
p.set_process_group(pg)
if 'ln' not in n and ('weight' in n or 'bias' in n):
p.set_tensor_spec(*spec)


@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('use_chunk', [False])
@parameterize('use_zero', [False])
@parameterize('placement_policy', ['cuda'])
def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

with ColoInitContext(device=get_current_device()):
model = model_builder()
model = model.cuda().half()
model = model.cuda()
torch_model = model_builder().cuda()

for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p)
torch_p.data.copy_(p.data)

world_size = torch.distributed.get_world_size()

Expand All @@ -93,23 +97,25 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pg)
optim = HybridAdam(model.parameters(), lr=1e-3)
optim = ZeroOptimizer(optim, model, initial_scale=32)
optim = ZeroOptimizer(optim, model, initial_scale=1)

amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

# print(chunk_manager)
check_param_equal(model, torch_model, pg)
model.train()
torch_model.train()

model.eval()
torch_model.eval()

set_seed(pg.dp_local_rank())
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 2:
break

logits = run_fwd_bwd(model, criterion, optim, input_ids, attn_mask)
input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
logits = run_fwd_bwd(model, criterion, optim, input_ids_colo, attn_mask)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert tensor_equal(logits, torch_logits)
check_grad_equal(model, torch_model, pg)
Expand All @@ -123,13 +129,13 @@ def run_dist(rank, world_size, port):
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
if world_size == 4:
run_gpt(tp_init_spec_func=init_1d_col_spec)
run_gpt(tp_init_spec_func=init_1d_row_spec)
# run_gpt(tp_init_spec_func=init_1d_row_spec)
else:
run_gpt()
run_gpt(tp_init_spec_func=init_1d_col_spec)


@pytest.mark.dist
@pytest.mark.skip("under development")
@pytest.mark.skip("buggy test")
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
Expand Down