[checkpoint] use gather_tensor in checkpoint and update its unit test #1339
New file (@@ -0,0 +1,50 @@):

```python
import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.distspec import _DistSpec


def gather_tensor(colo_tensor: ColoTensor) -> None:
    """Make colo_tensor replicated when the rank is 0."""
    if not colo_tensor.is_replicate():
        pg = colo_tensor.get_process_group()
        # for the group which contains rank 0
        if pg.tp_rank_list()[0] == 0:
            old_dist_spec = colo_tensor.dist_spec
            colo_tensor.to_replicate_()
            if dist.get_rank() != 0:
                colo_tensor.set_dist_spec(old_dist_spec)
```
Contributor: This line (`colo_tensor.set_dist_spec(old_dist_spec)`) triggers collective communication.

Author: There is no communication, since …
```python
        # synchronize all processes for unexpected problems
        dist.barrier()

    if dist.get_rank() == 0:
        setattr(colo_tensor, 'save_ready', True)  # set saving signature


def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
    """Reversal operation of `gather_tensor`."""
    if dist_spec.placement == 'r':
        dist.broadcast(colo_tensor.data, 0)
    else:
        global_size = colo_tensor.size_global()

        if dist.get_rank() == 0:
            entire_data = colo_tensor.data
        else:
            entire_data = torch.empty(global_size, device=colo_tensor.device)
        dist.broadcast(entire_data, 0)

        if dist.get_rank() == 0:
            colo_tensor.set_dist_spec(dist_spec)
        else:
            rep_tensor = ColoTensor(entire_data, ColoTensorSpec(
                pg=colo_tensor.get_process_group(),
                compute_attr=colo_tensor.compute_spec))
            rep_tensor.set_dist_spec(dist_spec)
            with torch.no_grad():
                colo_tensor.data.copy_(rep_tensor.data)

    # synchronize all processes for unexpected problems
    dist.barrier()
```
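For orientation, here is a minimal sketch of how these two helpers fit together around a save: gather each parameter so rank 0 holds the full tensor, write the checkpoint there, then scatter everything back to its original layout. The wrapper name `save_sharded_model` and the surrounding loop are illustrative assumptions, not the actual `save_checkpoint` implementation from this PR.

```python
import torch
import torch.distributed as dist

def save_sharded_model(model, path):
    """Hypothetical wrapper showing the intended gather/scatter round trip."""
    old_specs = {}
    for name, param in model.named_parameters():
        old_specs[name] = param.dist_spec       # remember the original layout
        gather_tensor(param)                    # rank 0 ends up with the full tensor

    if dist.get_rank() == 0:
        # only rank 0 holds fully gathered tensors (flagged via `save_ready`)
        torch.save(model.state_dict(), path)
    dist.barrier()                              # wait until the write finishes

    for name, param in model.named_parameters():
        scatter_tensor(param, old_specs[name])  # restore the original sharding
```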
Unit test file:

```diff
@@ -1,6 +1,7 @@
 import os, shutil
 import torch
 import pytest
+from copy import deepcopy
 from functools import partial
 
 import torch.multiprocessing as mp
@@ -15,8 +16,7 @@
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
-from colossalai.nn.parallel.data_parallel import ColoDDP
+from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.optimizer import ColossalaiOptimizer
@@ -63,8 +63,8 @@ def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
 
 
 def check_param_equal(model, torch_model):
-    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
-        assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)
+    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
+        assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape)
 
 
 def remove(path):
@@ -84,9 +84,13 @@ def compare_optims(optim1, optim2):
         if k not in state2:
             continue
         p2 = state2[k]
-        if isinstance(p1, ColoTensor):
-            assert isinstance(p2, ColoTensor)
-            assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
+        for n, t1 in p1.items():
+            if n not in p2:
+                continue
+            t2 = p2[n]
+            if isinstance(t1, ColoTensor):
+                assert isinstance(t2, ColoTensor)
+                assert torch.allclose(t1, t2, rtol=0, atol=0)
@@ -99,7 +103,6 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
     # set_seed(1)
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
-        model_reload = model_builder(checkpoint=True)
 
     if use_mp_reload:
         if 'bert' == model_name:
@@ -119,25 +122,26 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
             elif 'token_type_embeddings' in name and 'weight' in name:
                 init_1d_col_embedding(p, pg)
             elif p.process_group.tp_world_size() == 1:
-                p.redistribute(ReplicaSpec(), pg)
+                p.set_process_group(pg)
     elif "simple_net" == model_name:
         init_spec_func(model, pg)
 
+    model_reload = deepcopy(model)
     model = model.cuda()
-    model.train()
+    model.eval()
 
     model_reload = model_reload.cuda()
-    model_reload.train()
+    model_reload.eval()
 
     opt_class = torch.optim.Adam
     colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
     colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
-    run_reload = False
 
     for i, (data, label) in enumerate(train_dataloader):
 
         # Zero grad
         colo_optimizer.zero_grad()
+        colo_optimizer_reload.zero_grad()
 
         data = data.to(get_current_device())
         label = label.to(get_current_device())
@@ -155,43 +159,33 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
         loss.backward()
         loss_reload.backward()
 
-        if run_reload:
-            colo_optimizer_reload.zero_grad()
-            if criterion:
-                output_reload = model_reload(data)
-                loss_reload = criterion(output_reload, label)
-            else:
-                loss_reload = model_reload(data, label)
-            loss_reload.backward()
-            colo_optimizer_reload.step()
         colo_optimizer.step()
+        colo_optimizer_reload.step()
 
         if i > 2:
             break
 
     if not os.path.isdir('./checkpoint') and rank == 0:
         os.mkdir('./checkpoint')
-    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
-    dist.barrier()
-    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
-    dist.barrier()
-
-    # Since model is sharded, we merge them before param checking.
-    for p in model.parameters():
-        p.to_replicate_()
-
-    for p in model_reload.parameters():
-        p.to_replicate_()
+    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
+    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
 
     check_param_equal(model, model_reload)
     compare_optims(colo_optimizer, colo_optimizer_reload)
 
     if rank == 0:
         remove('./checkpoint')
     dist.barrier()
 
 
 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    for model_name in ['simple_net', 'bert']:
-        # TODO(haichen) add BERT in the test
```
Contributor: Inside a DP group, the input is replicated?

Author: It depends on which model is being used. We do not have a unified standard for now.
```diff
+    # the data loader of BERT is in DDP mode, so the input data is not replicated in the TP context
+    for model_name in ['simple_net']:
         _run_checkpoint(model_name,
                         init_1d_row_for_linear_weight_spec,
                         use_ddp,
```
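The `mp.spawn` wrapper for this test is not shown in the hunks above; in the usual Colossal-AI test pattern it would look roughly like the sketch below. The decorator set and parameter values are assumptions, not part of this diff.

```python
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
def test_checkpoint(world_size):
    # bind everything except the rank, which mp.spawn supplies as the first argument
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_ddp=False,
                       use_mp_reload=True,
                       test_scheduler=None)
    mp.spawn(run_func, nprocs=world_size)
```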
Contributor: In the next PR, you can merge model and optim into a single file, like
https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
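For reference, the linked recipe stores both state dicts in one file, roughly as below. This is generic PyTorch from the tutorial, assuming `model`, `optimizer`, `epoch`, and `loss` are already in scope; it is not code from this PR.

```python
import torch

# Save model and optimizer state together in a single checkpoint file.
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'checkpoint.pt')

# Restore both from the same file.
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
```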