Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions colossalai/nn/optimizer/colossalai_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
from torch import Tensor
Expand Down
42 changes: 38 additions & 4 deletions colossalai/utils/checkpoint/module_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, DistSpecManager
from colossalai.nn.optimizer import ColossalaiOptimizer
from copy import copy
from typing import Optional


def save_checkpoint(dire: str,
epoch: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None,
optimizer: Optional[ColossalaiOptimizer] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args,
**kwargs):
Expand All @@ -16,7 +19,7 @@ def save_checkpoint(dire: str,
dire (str): directory to save the checkpoint files.
epoch (int): the number of epoch
model (torch.nn.Module): a torch module initialized by ColoInitContext
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
"""

Expand All @@ -41,11 +44,21 @@ def save_checkpoint(dire: str,
# delete the new dict
del new_dict

optim_state_copy = copy(optimizer.state_dict())
for k, v in optim_state_copy['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
t.to_replicate_()
if dist.get_rank() == 0:
model_state = {'epoch': epoch, 'optim': optim_state_copy}
torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
del optim_state_copy


def load_checkpoint(dire,
epoch: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None,
optimizer: Optional[ColossalaiOptimizer] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args,
**kwargs):
Expand All @@ -56,7 +69,7 @@ def load_checkpoint(dire,
epoch (int): _description_
rank (int): _description_
model (torch.nn.Module): _description_
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
"""

Expand All @@ -74,3 +87,24 @@ def load_checkpoint(dire,
for k, v in model.state_dict().items():
if isinstance(v, ColoTensor):
v.set_tensor_spec(*mapping[k])

del mapping
mapping = dict()

for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
mapping[(k, n)] = (t.dist_spec, t.compute_spec)
t.to_replicate_()

colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
optimizer.load_state_dict(colo_checkpoint['optim'])

for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
# skip key not in mapping.
# For Adam, if step() has not been executed yet, there will be no exp_avg and exp_avg_sq in the optimizer state
if (k, n) not in mapping:
continue
t.set_tensor_spec(*mapping[(k, n)])
44 changes: 36 additions & 8 deletions tests/test_utils/test_colo_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ def remove(path):
raise ValueError("file {} is not a file or dir.".format(path))


def compare_optims(optim1, optim2):
    """Assert that the ColoTensor optimizer states of two optimizers match.

    Walks both optimizers' ``state_dict()['state']`` mappings and compares
    every ColoTensor entry (e.g. Adam's ``exp_avg`` / ``exp_avg_sq``) after
    replicating it across ranks. Parameter keys or state entries missing
    from ``optim2`` are skipped, since a parameter gains state only after
    ``step()`` has run on it.

    Args:
        optim1: reference optimizer (e.g. the one that was checkpointed).
        optim2: optimizer reloaded from the checkpoint to validate.
    """
    state1 = optim1.state_dict()['state']
    state2 = optim2.state_dict()['state']
    for k, p1 in state1.items():
        if k not in state2:
            continue
        p2 = state2[k]
        # Each per-parameter entry is itself a dict of tensors
        # (see load_checkpoint: `for n, t in v.items()`), so the ColoTensor
        # check must target the inner values. The original isinstance test
        # on the outer dict could never be True, so nothing was compared.
        for n, t1 in p1.items():
            if not isinstance(t1, ColoTensor):
                continue
            if n not in p2:
                continue
            t2 = p2[n]
            assert isinstance(t2, ColoTensor)
            assert torch.allclose(t1.to_replicate_(), t2.to_replicate_(), rtol=1e-3, atol=1e-1)


def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
Expand Down Expand Up @@ -117,7 +129,10 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
model_reload = model_reload.cuda()
model_reload.train()

colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.named_parameters(), r=0.1))
opt_class = torch.optim.Adam
colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
run_reload = False

for i, (data, label) in enumerate(train_dataloader):

Expand All @@ -130,22 +145,35 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
# Bcast rank0 data to all processes
if criterion:
output = model(data)
output_reload = model_reload(data)
loss = criterion(output, label)
loss_reload = criterion(output_reload, label)
else:
output = model(data, label)
loss = output
loss = model(data, label)
loss_reload = model_reload(data, label)

loss.backward()
colo_optimizer.step()
loss_reload.backward()

if run_reload:
colo_optimizer_reload.zero_grad()
if criterion:
output_reload = model_reload(data)
loss_reload = criterion(output_reload, label)
else:
loss_reload = model_reload(data, label)
loss_reload.backward()
colo_optimizer_reload.step()

if i > 2:
break

if not os.path.isdir('./checkpoint') and rank == 0:
os.mkdir('./checkpoint')
save_checkpoint('./checkpoint', 0, model, None, None)
save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, None, None)

# Since model is sharded, we merge them before param checking.
for p in model.parameters():
Expand All @@ -155,15 +183,15 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
p.to_replicate_()

check_param_equal(model, model_reload)

compare_optims(colo_optimizer, colo_optimizer_reload)
if rank == 0:
remove('./checkpoint')


def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size)
for model_name in ['bert', 'simple_net']:
for model_name in ['simple_net', 'bert']:
_run_checkpoint(model_name,
init_1d_row_for_linear_weight_spec,
use_ddp,
Expand Down