2 changes: 1 addition & 1 deletion colossalai/tensor/colo_tensor.py
@@ -262,7 +262,7 @@ def view_global(self, *args) -> 'ColoTensor':
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
return replicated_t.view(*args)

def size_global(self, args: Optional[int] = None):
def size_global(self, args: Optional[int] = None) -> torch.Size:
"""override the torch buildin size()
the shape passed in must be in a replicate placement.
Returns:
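For context, `size_global()` reports the full, pre-shard shape of a ColoTensor, which is why the `torch.Size` annotation matters for callers such as `scatter_tensor` below. A hedged illustration with made-up shapes:

```python
# Hypothetical: `t` is a ColoTensor sharded along dim 0 across 2 TP ranks,
# so each rank holds a [512, 1024] shard of a [1024, 1024] parameter.
full_shape = t.size_global()                    # torch.Size([1024, 1024])
buf = torch.empty(full_shape, device=t.device)  # e.g. how scatter_tensor sizes its receive buffer
```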
9 changes: 9 additions & 0 deletions colossalai/tensor/process_group.py
@@ -141,9 +141,18 @@ def __eq__(self, obj: 'ProcessGroup') -> bool:
def rank(self):
return self._rank

def ranks_in_group(self):
return self._rank_list

def world_size(self):
return self._world_size

def tp_rank_list(self):
return self._tp_rank_list

def dp_rank_list(self):
return self._dp_rank_list

def tp_local_rank(self):
return self._rank % self._tp_degree

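For context, a minimal sketch (not part of the diff) of what the new accessors are expected to return; the group layout and constructor arguments below are assumptions for illustration:

```python
from colossalai.tensor import ProcessGroup

# Hypothetical 4-process job arranged as tp_degree=2, dp_degree=2.
pg = ProcessGroup(ranks=[0, 1, 2, 3], tp_degree=2, dp_degree=2)

pg.ranks_in_group()  # [0, 1, 2, 3]: every rank covered by this group
pg.tp_rank_list()    # the tensor-parallel group of the current rank, e.g. [0, 1] on rank 0
pg.dp_rank_list()    # the data-parallel group of the current rank, e.g. [0, 2] on rank 0
```

`gather_tensor()` in the new checkpoint/utils.py below relies on `tp_rank_list()[0] == 0` to decide whether the current process belongs to the tensor-parallel group that contains rank 0.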
118 changes: 68 additions & 50 deletions colossalai/utils/checkpoint/module_checkpoint.py
@@ -1,8 +1,8 @@
import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, DistSpecManager
from colossalai.tensor import ColoTensor
from colossalai.nn.optimizer import ColossalaiOptimizer
from copy import copy
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
from typing import Optional


@@ -22,37 +22,52 @@ def save_checkpoint(dire: str,
optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
"""

mapping = dict()
new_dict = dict()

rank = dist.get_rank()
model_state = model.state_dict()
# save the dist context about the tensors in a new dict, while still maintaining the original dict.
for k, v in model.state_dict().items():
for k, v in model_state.items():
if isinstance(v, ColoTensor):
mapping[k] = (v.dist_spec, v.compute_spec)
new_dict[k] = v.to_replicate().detach()
else:
new_dict[k] = v
if dist.get_rank() == 0:
for k, v in new_dict.items():
gather_tensor(v) # gather shared tensors to rank0
# don't recover tensors in rank0, since the dict is only a copy of model

if rank == 0:
# sanity check
for k, v in model_state.items():
if isinstance(v, ColoTensor):
assert v.save_ready
assert v.is_replicate()
delattr(v, 'save_ready')
# model saving
save_state = {'epoch': epoch, 'model': model_state}
Contributor:

In the next PR, you can merge model and optim into a single file:

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

like
https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html

torch.save(save_state, dire + '/epoch_{}_model.pth'.format(epoch))

model_state = {'epoch': epoch, 'model': new_dict}
torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))

# delete the new dict
del new_dict
# delete old dicts
del model_state
# synchronize all the processes
dist.barrier()

optim_state_copy = copy(optimizer.state_dict())
for k, v in optim_state_copy['state'].items():
mapping = dict()
optim_state = optimizer.state_dict()
for k, v in optim_state['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
t.to_replicate_()
if dist.get_rank() == 0:
model_state = {'epoch': epoch, 'optim': optim_state_copy}
torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
del optim_state_copy
mapping[(k, n)] = t.dist_spec
gather_tensor(t)

if rank == 0:
save_state = {'epoch': epoch, 'optim': optim_state}
torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
# recover colo tensors in rank0
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
assert hasattr(t, 'save_ready')
t.set_dist_spec(mapping[(k, n)])
delattr(t, 'save_ready')

del optim_state
del mapping
dist.barrier()
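
To summarize the rewritten model-saving path above (a condensed restatement of the diff, not additional PR code): every ColoTensor in the state dict is gathered onto the group that contains rank 0, rank 0 alone serializes the dict, and all ranks synchronize afterwards.

```python
# Condensed view of the new model-saving flow (illustrative only).
model_state = model.state_dict()
for k, v in model_state.items():
    if isinstance(v, ColoTensor):
        gather_tensor(v)            # replicate the shards onto the TP group containing rank 0

if dist.get_rank() == 0:
    # on rank 0 each ColoTensor is now replicated and carries the `save_ready` flag
    torch.save({'epoch': epoch, 'model': model_state},
               dire + '/epoch_{}_model.pth'.format(epoch))

dist.barrier()                      # keep the other ranks in step

# The optimizer state follows the same pattern, except the original dist specs are
# remembered in `mapping` and restored on rank 0 after torch.save().
```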


def load_checkpoint(dire,
@@ -72,39 +87,42 @@ def load_checkpoint(dire,
optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
"""

rank = dist.get_rank()
mapping = dict()
for k, v in model.state_dict().items():
if isinstance(v, ColoTensor):
mapping[k] = (v.dist_spec, v.compute_spec)
v.to_replicate_()

model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
model.load_state_dict(model_state['model'])

# reset tensors to original dist spec.
with DistSpecManager.no_grad():
for k, v in model.state_dict().items():
if isinstance(v, ColoTensor):
v.set_tensor_spec(*mapping[k])

for n, p in model.named_parameters():
if isinstance(p, ColoTensor):
mapping[n] = p.dist_spec
gather_tensor(p)

if rank == 0:
load_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
model.load_state_dict(load_state['model'])
dist.barrier()

# scatter loaded parameters
for n, p in model.named_parameters():
if isinstance(p, ColoTensor):
scatter_tensor(p, mapping[n])
if rank == 0:
assert hasattr(p, 'save_ready')
delattr(p, 'save_ready')
del mapping
mapping = dict()

mapping = dict()
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
mapping[(k, n)] = (t.dist_spec, t.compute_spec)
t.to_replicate_()
mapping[(k, n)] = t.dist_spec
gather_tensor(t)

colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
optimizer.load_state_dict(colo_checkpoint['optim'])
if rank == 0:
colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
optimizer.load_state_dict(colo_checkpoint['optim'])
dist.barrier()

for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
# skip key not in mapping.
# For Adam, if it does not execute step() even once, there will be no exp_avg and exp_avg_sq in the optimizer
if (k, n) not in mapping:
continue
t.set_tensor_spec(*mapping[(k, n)])
scatter_tensor(t, mapping[(k, n)])

del mapping
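
For orientation, a sketch of how the two entry points are meant to be called; both are collective (they gather, scatter, and hit barriers), so every rank must execute them. The directory and epoch below mirror the test further down rather than adding new API:

```python
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint

# Run on every rank; the checkpoint files themselves are written by rank 0 only.
save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
```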
50 changes: 50 additions & 0 deletions colossalai/utils/checkpoint/utils.py
@@ -0,0 +1,50 @@
import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.distspec import _DistSpec


def gather_tensor(colo_tensor: ColoTensor) -> None:
"""Make colo_tensor replicated when the rank is 0
"""
if not colo_tensor.is_replicate():
pg = colo_tensor.get_process_group()
# for the group which contains rank 0
if pg.tp_rank_list()[0] == 0:
old_dist_spec = colo_tensor.dist_spec
colo_tensor.to_replicate_()
if dist.get_rank() != 0:
colo_tensor.set_dist_spec(old_dist_spec)
Contributor:

This line triggers collective communication.
Will there be potential blocking if rank 0 is excluded?

Contributor Author:

There is no communication, since old_dist_spec must be SHARD and we have a replicated tensor here.


# synchronize all processes for unexpected problems
dist.barrier()

if dist.get_rank() == 0:
setattr(colo_tensor, 'save_ready', True) # set saving signature


def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
"""Reversal operation of `gather_tensor`.
"""
if dist_spec.placement == 'r':
dist.broadcast(colo_tensor.data, 0)
else:
global_size = colo_tensor.size_global()

if dist.get_rank() == 0:
entire_data = colo_tensor.data
else:
entire_data = torch.empty(global_size, device=colo_tensor.device)
dist.broadcast(entire_data, 0)

if dist.get_rank() == 0:
colo_tensor.set_dist_spec(dist_spec)
else:
rep_tensor = ColoTensor(entire_data, ColoTensorSpec(
pg=colo_tensor.get_process_group(),
compute_attr=colo_tensor.compute_spec))
rep_tensor.set_dist_spec(dist_spec)
with torch.no_grad():
colo_tensor.data.copy_(rep_tensor.data)
# synchronize all processes for unexpected problems
dist.barrier()
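
A brief, hedged illustration of how the two helpers pair up around a plain torch.save; the tensor `param` and its process group are assumed, not taken from the PR:

```python
# Round-trip sketch: gather a sharded ColoTensor onto rank 0, persist it,
# then scatter it back to its original layout. Must run on every rank.
old_spec = param.dist_spec                 # remember the shard layout
gather_tensor(param)                       # param becomes replicated on the TP group containing rank 0

if dist.get_rank() == 0:
    torch.save(param.data, 'param.pth')    # rank 0 holds the full tensor at this point

scatter_tensor(param, old_spec)            # broadcast from rank 0 and restore the shard layout everywhere
```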
56 changes: 25 additions & 31 deletions tests/test_utils/test_colo_checkpoint.py
@@ -1,6 +1,7 @@
import os, shutil
import torch
import pytest
from copy import deepcopy
from functools import partial

import torch.multiprocessing as mp
@@ -15,8 +16,7 @@
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
from colossalai.nn.optimizer import ColossalaiOptimizer

@@ -63,8 +63,8 @@ def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):


def check_param_equal(model, torch_model):
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)
for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape)


def remove(path):
@@ -84,9 +84,13 @@ def compare_optims(optim1, optim2):
if k not in state2:
continue
p2 = state2[k]
if isinstance(p1, ColoTensor):
assert isinstance(p2, ColoTensor)
assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
for n, t1 in p1.items():
if n not in p2:
continue
t2 = p2[n]
if isinstance(t1, ColoTensor):
assert isinstance(t2, ColoTensor)
assert torch.allclose(t1, t2, rtol=0, atol=0)


def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
@@ -99,7 +103,6 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
# set_seed(1)
with ColoInitContext(device=get_current_device()):
model = model_builder(checkpoint=True)
model_reload = model_builder(checkpoint=True)

if use_mp_reload:
if 'bert' == model_name:
@@ -119,25 +122,26 @@
elif 'token_type_embeddings' in name and 'weight' in name:
init_1d_col_embedding(p, pg)
elif p.process_group.tp_world_size() == 1:
p.redistribute(ReplicaSpec(), pg)
p.set_process_group(pg)
elif "simple_net" == model_name:
init_spec_func(model, pg)

model_reload = deepcopy(model)
model = model.cuda()
model.train()
model.eval()

model_reload = model_reload.cuda()
model_reload.train()
model_reload.eval()

opt_class = torch.optim.Adam
colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
run_reload = False

for i, (data, label) in enumerate(train_dataloader):

# Zero grad
colo_optimizer.zero_grad()
colo_optimizer_reload.zero_grad()

data = data.to(get_current_device())
label = label.to(get_current_device())
@@ -155,43 +159,33 @@
loss.backward()
loss_reload.backward()

if run_reload:
colo_optimizer_reload.zero_grad()
if criterion:
output_reload = model_reload(data)
loss_reload = criterion(output_reload, label)
else:
loss_reload = model_reload(data, label)
loss_reload.backward()
colo_optimizer_reload.step()
colo_optimizer.step()
colo_optimizer_reload.step()

if i > 2:
break

if not os.path.isdir('./checkpoint') and rank == 0:
os.mkdir('./checkpoint')
save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
dist.barrier()

# Since model is sharded, we merge them before param checking.
for p in model.parameters():
p.to_replicate_()

for p in model_reload.parameters():
p.to_replicate_()
save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)

check_param_equal(model, model_reload)
compare_optims(colo_optimizer, colo_optimizer_reload)

if rank == 0:
remove('./checkpoint')
dist.barrier()


def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size)
for model_name in ['simple_net', 'bert']:
# TODO(haichen) add BERT in the test
Contributor:

Inside a DP group, the input is replicated?

Contributor Author (@1SAA, Jul 19, 2022):

It depends on which model is being used. We do not have a unified standard now.

# the data loader of BERT is in DDP mode, so the input data is not replicated in the TP context
for model_name in ['simple_net']:
_run_checkpoint(model_name,
init_1d_row_for_linear_weight_spec,
use_ddp,