From 3dfba239632b9fe4f8ba5233258c0ded482eeaa3 Mon Sep 17 00:00:00 2001 From: lclgy Date: Mon, 28 Aug 2023 15:56:53 +0800 Subject: [PATCH 1/7] fix zero ckptio with offload --- colossalai/zero/low_level/low_level_optim.py | 16 ++++++++++------ .../test_low_level_zero_checkpoint_io.py | 9 +++++---- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 8f2232393240..44f9a037a5cd 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -528,9 +528,12 @@ def state_dict(self) -> Dict: for k, v in state.items(): if isinstance(v, torch.Tensor) and k != 'step': working_param = self._param_store.master_to_working_param[id(param)] - gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)] - dist.all_gather(gather_tensor, v, group=self.dp_pg) - param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param) + gather_tensor = [ + torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size) + ] + dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) + param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as( + working_param).to(v.device) zero_state[param][k] = param_state states_dict = self._pack_state(zero_state) @@ -585,9 +588,10 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i for k, v in states.items(): if isinstance(v, torch.Tensor) and k != 'step': - state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)] - dist.all_gather(state_tensor, v, group=self.dp_pg) - state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param) + state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)] + dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) + state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as( + working_param).to(v.device) current_block_size += state_tensor.numel() current_block[k] = state_tensor diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index a94e8d42c78e..69fb7642e31e 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -17,10 +17,11 @@ @clear_cache_before_run() -@parameterize('stage', [2]) +@parameterize('stage', [1, 2]) @parameterize('shard', [True, False]) -def check_low_level_zero_checkpointIO(stage: int, shard: bool): - plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32) +@parameterize('offload', [True]) +def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): + plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload) booster = Booster(plugin=plugin) model = resnet18() criterion = lambda x: x.mean() @@ -50,7 +51,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool): check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False) def run_dist(rank, world_size, port): From c66b42752f0ed4376b983e6c173acf2c0089507e Mon Sep 17 00:00:00 2001 From: 
lclgy Date: Mon, 28 Aug 2023 16:21:49 +0800 Subject: [PATCH 2/7] fix load device --- colossalai/zero/low_level/low_level_optim.py | 6 ++---- .../test_checkpoint_io/test_low_level_zero_checkpoint_io.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 44f9a037a5cd..a7f5e486f226 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -80,9 +80,6 @@ def __init__( tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None): - # TODO: - # 1. state_dict for checkpoint IO - super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]['params'][0].dtype self._logger = get_dist_logger() @@ -556,7 +553,8 @@ def load_state_dict(self, state_dict: Dict): if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) v_list = v.split(v.numel() // self._world_size) - zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach() + device = 'cpu' if self._cpu_offload else 'cuda' + zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach() self.optim.load_state_dict(zero_state_dict) zero_state_dict = dict() diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 69fb7642e31e..d820fcde39c8 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -19,7 +19,7 @@ @clear_cache_before_run() @parameterize('stage', [1, 2]) @parameterize('shard', [True, False]) -@parameterize('offload', [True]) +@parameterize('offload', [False, True]) def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload) booster = Booster(plugin=plugin) From fec402c8243f47a87ada8a9a85375fbed93a5b24 Mon Sep 17 00:00:00 2001 From: lclgy Date: Mon, 28 Aug 2023 18:38:13 +0800 Subject: [PATCH 3/7] saved tensors in ckpt should be on CPU --- colossalai/zero/low_level/low_level_optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index a7f5e486f226..96d5902e893f 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -530,7 +530,7 @@ def state_dict(self) -> Dict: ] dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as( - working_param).to(v.device) + working_param).cpu() zero_state[param][k] = param_state states_dict = self._pack_state(zero_state) @@ -589,7 +589,7 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)] dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as( - working_param).to(v.device) + working_param).cpu() current_block_size += state_tensor.numel() current_block[k] = state_tensor From 67783582fae0242744152fb57edc1c3bc51b4772 Mon Sep 17 00:00:00 2001 From: lclgy Date: Tue, 29 Aug 2023 11:50:55 +0800 Subject: [PATCH 4/7] fix unit test --- 
tests/test_zero/test_low_level/test_zero_ckpt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_zero/test_low_level/test_zero_ckpt.py b/tests/test_zero/test_low_level/test_zero_ckpt.py index 23356fe718a6..421e1044715a 100644 --- a/tests/test_zero/test_low_level/test_zero_ckpt.py +++ b/tests/test_zero/test_low_level/test_zero_ckpt.py @@ -37,7 +37,7 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32): atol = 4e-3 a = a.detach().to(dtype) - b = b.detach().to(dtype) + b = b.detach().to(dtype).to(a.device) assert_close(a, b, rtol=rtol, atol=atol) @@ -90,6 +90,7 @@ def exam_zero_1_torch_ddp_ckpt(): # examine the original state dict for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()): for t_v, z_v in zip(torch_state.values(), zero_state.values()): + print(t_v, z_v) loose_close(t_v, z_v) # empty the optimzer state From 1b440a86a6c58fb8f85fc6da6fbe7e6af1f3435a Mon Sep 17 00:00:00 2001 From: lclgy Date: Tue, 29 Aug 2023 17:41:26 +0800 Subject: [PATCH 5/7] fix unit test --- tests/test_zero/test_low_level/test_zero_ckpt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_zero/test_low_level/test_zero_ckpt.py b/tests/test_zero/test_low_level/test_zero_ckpt.py index 421e1044715a..ab811c6b4d3c 100644 --- a/tests/test_zero/test_low_level/test_zero_ckpt.py +++ b/tests/test_zero/test_low_level/test_zero_ckpt.py @@ -90,7 +90,6 @@ def exam_zero_1_torch_ddp_ckpt(): # examine the original state dict for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()): for t_v, z_v in zip(torch_state.values(), zero_state.values()): - print(t_v, z_v) loose_close(t_v, z_v) # empty the optimzer state From 0193e4717d67145a1a70c4eaa882f70b58fa9fbe Mon Sep 17 00:00:00 2001 From: lclgy Date: Wed, 30 Aug 2023 10:49:15 +0800 Subject: [PATCH 6/7] add clear cache --- tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index d820fcde39c8..93c70c9a450f 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -60,6 +60,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_low_level_zero_checkpointIO(): spawn(run_dist, 2) From 8a8c9cfdbaadaa421036c7774a435f54c661aaeb Mon Sep 17 00:00:00 2001 From: lclgy Date: Wed, 30 Aug 2023 11:01:40 +0800 Subject: [PATCH 7/7] save memory for CI --- .../test_low_level_zero_checkpoint_io.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 93c70c9a450f..3faa395b5935 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -16,8 +16,10 @@ ) +# stage 1 and 2 process the optimizer/mode the same way +# only test 2 is fine @clear_cache_before_run() -@parameterize('stage', [1, 2]) +@parameterize('stage', [2]) @parameterize('shard', [True, False]) @parameterize('offload', [False, True]) def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): @@ -28,8 +30,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): optimizer = 
HybridAdam((model.parameters()), lr=0.001) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - x = torch.randn(4, 3, 224, 224) - x = x.to('cuda') + x = torch.randn(1, 3, 224, 224, device='cuda') output = model(x) loss = criterion(output) booster.backward(loss, optimizer) @@ -57,6 +58,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): def run_dist(rank, world_size, port): colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost') check_low_level_zero_checkpointIO() + torch.cuda.empty_cache() @rerun_if_address_is_in_use()
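
The thread running through these patches is a single pattern: when cpu_offload is enabled, each rank's optimizer-state shard lives on the CPU, so it has to be moved to CUDA before dist.all_gather (the NCCL backend used on GPUs cannot gather CPU tensors), and the reconstructed full tensor is then placed on the CPU so the saved checkpoint does not depend on the training device; on load, the state is re-split and each rank's slice is put back on CPU or CUDA according to the offload setting. The sketch below is a minimal illustration of that save/load round trip under those assumptions — it is not the actual LowLevelZeroOptimizer code, and the helper names (gather_state_for_save, shard_state_for_load) and the even-padding convention are hypothetical stand-ins for the real bookkeeping.

# Minimal sketch (not the ColossalAI implementation) of gathering a CPU-offloaded,
# flattened optimizer-state shard for checkpointing and re-sharding it on load.
# Assumes torch.distributed is already initialized and each rank holds `local_shard`.
import torch
import torch.distributed as dist


def gather_state_for_save(local_shard: torch.Tensor, working_param: torch.Tensor,
                          group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    # Shards may sit on the CPU when cpu_offload is enabled; NCCL all_gather needs CUDA tensors.
    shard_cuda = local_shard.cuda()
    buffers = [torch.zeros(shard_cuda.shape, device='cuda', dtype=shard_cuda.dtype)
               for _ in range(world_size)]
    dist.all_gather(buffers, shard_cuda, group=group)
    # Drop the padding added for even sharding, restore the working-param shape,
    # and keep the result on the CPU so the checkpoint itself is device-agnostic.
    full = torch.stack(buffers).view(-1)[:working_param.numel()]
    return full.reshape_as(working_param).cpu()


def shard_state_for_load(full_state: torch.Tensor, rank: int, world_size: int,
                         cpu_offload: bool) -> torch.Tensor:
    # Pad so the flattened state splits evenly, mirroring how the params were flattened.
    flat = full_state.reshape(-1)
    padding = (-flat.numel()) % world_size
    if padding > 0:
        flat = torch.nn.functional.pad(flat, [0, padding])
    shard = flat.split(flat.numel() // world_size)[rank]
    # Place the shard where the optimizer expects it: CPU when offloading, CUDA otherwise.
    return shard.to('cpu' if cpu_offload else 'cuda').detach().clone()

Keeping the gathered state on the CPU is what lets the same checkpoint be reloaded whether or not the new run offloads, which is exactly what the updated test exercises by toggling the offload parameter and comparing the wrapped optimizer's optim.state_dict() before and after the round trip.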