From 507fee002b217d61ebb848fd62bfc64239a952a5 Mon Sep 17 00:00:00 2001
From: jiaruifang
Date: Thu, 14 Jul 2022 17:46:55 +0800
Subject: [PATCH 1/4] [Optimizer] add colo optimizer checkpoint

---
 .../utils/checkpoint/module_checkpoint.py | 47 ++++++++++++++++++-
 tests/test_utils/test_colo_checkpoint.py  | 30 +++++++++---
 2 files changed, 68 insertions(+), 9 deletions(-)

diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py
index 119d719b2c33..ae90c943e66a 100644
--- a/colossalai/utils/checkpoint/module_checkpoint.py
+++ b/colossalai/utils/checkpoint/module_checkpoint.py
@@ -1,12 +1,22 @@
 import torch
 import torch.distributed as dist
 from colossalai.tensor import ColoTensor, DistSpecManager
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from copy import copy
+from typing import Optional
+
+
+def _print_optim_state(optimizer):
+    state = optimizer.state_dict()['state']
+    print('optimizer state ', type(state), len(state))
+    for k, v in state.items():
+        print(k, v)
 
 
 def save_checkpoint(dire: str,
                     epoch: int,
                     model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer = None,
+                    optimizer: Optional[ColossalaiOptimizer] = None,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     *args,
                     **kwargs):
@@ -41,11 +51,21 @@ def save_checkpoint(dire: str,
     # delete the new dict
     del new_dict
 
+    optim_state_copy = copy(optimizer.state_dict())
+    for k, v in optim_state_copy['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                t.to_replicate_()
+    if dist.get_rank() == 0:
+        model_state = {'epoch': epoch, 'optim': optim_state_copy}
+        torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
+    del optim_state_copy
+
 
 def load_checkpoint(dire,
                     epoch: int,
                     model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer = None,
+                    optimizer: Optional[ColossalaiOptimizer] = None,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     *args,
                     **kwargs):
@@ -74,3 +94,26 @@ def load_checkpoint(dire,
     for k, v in model.state_dict().items():
         if isinstance(v, ColoTensor):
             v.set_tensor_spec(*mapping[k])
+
+    del mapping
+    mapping = dict()
+
+    # _print_optim_state(optimizer)
+
+    for k, v in optimizer.state_dict()['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                mapping[(k, n)] = (t.dist_spec, t.compute_spec)
+                t.to_replicate_()
+
+    colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
+    optimizer.load_state_dict(colo_checkpoint['optim'])
+
+    for k, v in optimizer.state_dict()['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                # skip key not in mapping.
+                # For Adam, if step() has not been executed yet, there is no exp_avg or exp_avg_sq in the optimizer state
+                if (k, n) not in mapping:
+                    continue
+                t.set_tensor_spec(*mapping[(k, n)])
diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py
index edc463b0dbea..f6575163e05f 100644
--- a/tests/test_utils/test_colo_checkpoint.py
+++ b/tests/test_utils/test_colo_checkpoint.py
@@ -117,7 +117,10 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
     model_reload = model_reload.cuda()
     model_reload.train()
 
-    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.named_parameters(), r=0.1))
+    opt_class = torch.optim.Adam
+    colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
+    colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
+    run_reload = False
 
     for i, (data, label) in enumerate(train_dataloader):
@@ -130,22 +133,35 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
         # Bcast rank0 data to all processes
         if criterion:
             output = model(data)
+            output_reload = model_reload(data)
             loss = criterion(output, label)
+            loss_reload = criterion(output_reload, label)
         else:
-            output = model(data, label)
-            loss = output
+            loss = model(data, label)
+            loss_reload = model_reload(data, label)
 
         loss.backward()
-        colo_optimizer.step()
+        loss_reload.backward()
+
+        if run_reload:
+            colo_optimizer_reload.zero_grad()
+            if criterion:
+                output_reload = model_reload(data)
+                loss_reload = criterion(output_reload, label)
+            else:
+                loss_reload = model_reload(data, label)
+            loss_reload.backward()
+            colo_optimizer_reload.step()
 
         if i > 2:
             break
 
     if not os.path.isdir('./checkpoint') and rank == 0:
         os.mkdir('./checkpoint')
-    save_checkpoint('./checkpoint', 0, model, None, None)
+    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
+    dist.barrier()
+    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
     dist.barrier()
-    load_checkpoint('./checkpoint', 0, model_reload, None, None)
 
     # Since model is sharded, we merge them before param checking.
     for p in model.parameters():
@@ -163,7 +179,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    for model_name in ['bert', 'simple_net']:
+    for model_name in ['simple_net']:
         _run_checkpoint(model_name,
                         init_1d_row_for_linear_weight_spec,
                         use_ddp,

From f86faceff552b0b7b3dff665cf9ddadb55aa9ab0 Mon Sep 17 00:00:00 2001
From: jiaruifang
Date: Thu, 14 Jul 2022 17:52:32 +0800
Subject: [PATCH 2/4] add colo optimizer checkpoint unittests

---
 colossalai/nn/optimizer/colossalai_optimizer.py |  3 ---
 tests/test_utils/test_colo_checkpoint.py        | 14 +++++++++++++-
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/colossalai/nn/optimizer/colossalai_optimizer.py b/colossalai/nn/optimizer/colossalai_optimizer.py
index fb0c43903509..34f5a9541975 100644
--- a/colossalai/nn/optimizer/colossalai_optimizer.py
+++ b/colossalai/nn/optimizer/colossalai_optimizer.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
 import torch
 import torch.nn as nn
 from torch import Tensor
diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py
index f6575163e05f..be4d62c6cb2b 100644
--- a/tests/test_utils/test_colo_checkpoint.py
+++ b/tests/test_utils/test_colo_checkpoint.py
@@ -77,6 +77,18 @@ def remove(path):
         raise ValueError("file {} is not a file or dir.".format(path))
 
 
+def compare_optims(optim1, optim2):
+    state1 = optim1.state_dict()['state']
+    state2 = optim2.state_dict()['state']
+    for k, p1 in state1.items():
+        if k not in state2:
+            continue
+        p2 = state2[k]
+        if isinstance(p1, ColoTensor):
+            assert isinstance(p2, ColoTensor)
+            assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
+
+
 def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -171,7 +183,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
         p.to_replicate_()
 
     check_param_equal(model, model_reload)
-
+    compare_optims(colo_optimizer, colo_optimizer_reload)
     if rank == 0:
         remove('./checkpoint')

From cba3d0345fd34296335249763d5be14d6fb3f593 Mon Sep 17 00:00:00 2001
From: jiaruifang
Date: Thu, 14 Jul 2022 17:55:23 +0800
Subject: [PATCH 3/4] polish

---
 tests/test_utils/test_colo_checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py
index be4d62c6cb2b..524a39be1d9d 100644
--- a/tests/test_utils/test_colo_checkpoint.py
+++ b/tests/test_utils/test_colo_checkpoint.py
@@ -191,7 +191,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    for model_name in ['simple_net']:
+    for model_name in ['simple_net', 'bert']:
         _run_checkpoint(model_name,
                         init_1d_row_for_linear_weight_spec,
                         use_ddp,

From a3054221553ac8327c48327a24b5dd0ace8bc829 Mon Sep 17 00:00:00 2001
From: jiaruifang
Date: Thu, 14 Jul 2022 19:38:53 +0800
Subject: [PATCH 4/4] polish code
---
 colossalai/utils/checkpoint/module_checkpoint.py | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py
index ae90c943e66a..81370ad0fff5 100644
--- a/colossalai/utils/checkpoint/module_checkpoint.py
+++ b/colossalai/utils/checkpoint/module_checkpoint.py
@@ -6,13 +6,6 @@
 from typing import Optional
 
 
-def _print_optim_state(optimizer):
-    state = optimizer.state_dict()['state']
-    print('optimizer state ', type(state), len(state))
-    for k, v in state.items():
-        print(k, v)
-
-
 def save_checkpoint(dire: str,
                     epoch: int,
                     model: torch.nn.Module,
@@ -26,7 +19,7 @@ def save_checkpoint(dire: str,
         dire (str): directory to save the checkpoint files.
         epoch (int): the number of epoch
         model (torch.nn.Module): a torch module initialized by ColoInitContext
-        optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
+        optimizer (ColossalaiOptimizer, optional): optimizer to checkpoint. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
@@ -76,7 +69,7 @@ def load_checkpoint(dire,
         epoch (int): _description_
         rank (int): _description_
         model (torch.nn.Module): _description_
-        optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
+        optimizer (ColossalaiOptimizer, optional): optimizer whose states are restored. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
     """
@@ -98,8 +91,6 @@ def load_checkpoint(dire,
     del mapping
     mapping = dict()
 
-    # _print_optim_state(optimizer)
-
     for k, v in optimizer.state_dict()['state'].items():
         for n, t in v.items():
             if isinstance(t, ColoTensor):
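
For reference, below is a minimal usage sketch of the optimizer-aware checkpoint API that this series introduces, mirroring the unit test above. It is an illustration, not part of the patches: it assumes the distributed environment has already been initialized with colossalai.launch, that both models were built under ColoInitContext so their parameters (and therefore the Adam states) are ColoTensors, and that the import path follows the file location colossalai/utils/checkpoint/module_checkpoint.py; build_model_and_data is a hypothetical helper, not a real API.

import torch
import torch.distributed as dist

from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.checkpoint.module_checkpoint import save_checkpoint, load_checkpoint

# Hypothetical helper: returns two identically built ColoTensor models plus data.
model, model_reload, train_dataloader, criterion = build_model_and_data()

colo_optimizer = ColossalaiOptimizer(torch.optim.Adam(model.parameters(), lr=0.1))
colo_optimizer_reload = ColossalaiOptimizer(torch.optim.Adam(model_reload.parameters(), lr=0.1))

# Run at least one step so Adam materializes exp_avg / exp_avg_sq; otherwise the
# saved optimizer state is empty (see the comment in load_checkpoint above).
for data, label in train_dataloader:
    colo_optimizer.zero_grad()
    loss = criterion(model(data), label)
    loss.backward()
    colo_optimizer.step()
    break

# Every rank must call save/load, since to_replicate_() on sharded ColoTensors
# involves cross-rank communication; only rank 0 writes epoch_0_optim.pth to disk.
save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
dist.barrier()

After the final barrier, model_reload and colo_optimizer_reload should match model and colo_optimizer, which is what check_param_equal and compare_optims verify in the test.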