From 814a608d0aed0692a4b539d5d73077d814e7e777 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 11 Feb 2025 14:21:59 +0800
Subject: [PATCH 1/5] fix checkpoint io for 3d

---
 .../hybrid_parallel_checkpoint_io.py          | 58 +++++++++----------
 1 file changed, 27 insertions(+), 31 deletions(-)

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 1b7ae18889fd..f836b6b79aa7 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -5,6 +5,7 @@
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
+from collections import defaultdict
 
 import torch
 import torch.distributed as dist
@@ -12,6 +13,8 @@
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
+from colossalai.accelerator import get_accelerator
+from torch.optim import Optimizer
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -724,26 +727,33 @@ def _get_param_id_from_optimizer_param(
                 state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
                 if not low_cpu_mem_mode:
                     state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
-                load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+                self.load_states_into_optimizer(optimizer.optim, state_dict, id_map)
                 loaded_file.add(filename)
 
-        # Then shard the loaded optimizer states if using tp/zero.
-        for param, state in optimizer.optim.state.items():
-            device = param.device
-            if master_to_working_map is not None:
-                working_param = master_to_working_map[id(param)]
-            else:
-                working_param = param
-            original_shape = optimizer.param_info["param2shape"][id(working_param)]
-            sharded_state = self.shard_from_complete_optimizer_state(
-                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
-            )
-            optimizer.optim.state[param] = sharded_state
-
         sharded_optimizer_loading_epilogue(optimizer.optim)
         if self.verbose and self.coordinator.is_master():
             logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
 
+    def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_map: dict):
+        state_dict = {int(k): v for k, v in state_dict.items()}
+        new_states = defaultdict(dict)
+        master_to_working_map = optimizer.get_master_to_working_map()
+        for k, state in state_dict.items():
+            if k in id_map:
+                param = id_map[k]
+                device = param.device
+                dtype = param.dtype
+                if master_to_working_map is not None:
+                    working_param = master_to_working_map[id(param)]
+                else:
+                    working_param = param
+                original_shape = optimizer.param_info["param2shape"][id(working_param)]
+                new_states[param] = self.shard_from_complete_optimizer_state(
+                    state, current_shape=working_param.shape, original_shape=original_shape, device=device, dtype=dtype, inplace=True
+                )
+            get_accelerator().synchronize()
+            optimizer.state.update(new_states)
+
     def save_unsharded_model(
         self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
     ):
@@ -988,22 +998,7 @@ def _get_param_id_from_optimizer_param(
             for param in pg["params"]:
                 param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 id_map[param_id] = param
-        load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
-
-        # Then shard the loaded optimizer states if using tp/zero.
-        for param, state in optimizer.optim.state.items():
-            if param is None:
-                continue
-            device = param.device
-            if master_to_working_map is not None:
-                working_param = master_to_working_map[id(param)]
-            else:
-                working_param = param
-            original_shape = optimizer.param_info["param2shape"][id(working_param)]
-            sharded_state = self.shard_from_complete_optimizer_state(
-                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
-            )
-            optimizer.optim.state[param] = sharded_state
+        self.load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map)
 
         sharded_optimizer_loading_epilogue(optimizer.optim)
 
@@ -1086,6 +1081,7 @@ def shard_from_complete_optimizer_state(
         current_shape: torch.Size,
         original_shape: torch.Size,
         device: torch.device,
+        dtype: torch.dtype,
         inplace: bool,
     ) -> OrderedDict:
         """
@@ -1135,7 +1131,7 @@ def shard_from_complete_optimizer_state(
                         slice_size = v.numel() // self.global_dp_size
                         v = v.split(slice_size, dim=0)[self.dp_rank]
 
-                state_[k] = v.detach().clone().to(device)
+                state_[k] = v.detach().clone().to(dtype).to(device)
 
         return state_

From 4f4556b136c9771eb7f8fa68aba9473726e9bed5 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 11 Feb 2025 14:32:57 +0800
Subject: [PATCH 2/5] fix

---
 colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index f836b6b79aa7..4a6dae7392a0 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -40,7 +40,6 @@
     load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
-    load_states_into_optimizer,
     save_config_file,
     save_param_groups,
     save_state_dict,
@@ -727,7 +726,7 @@ def _get_param_id_from_optimizer_param(
                 state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
                 if not low_cpu_mem_mode:
                     state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
-                self.load_states_into_optimizer(optimizer.optim, state_dict, id_map)
+                self.load_states_into_optimizer(optimizer, state_dict, id_map)
                 loaded_file.add(filename)
 
         sharded_optimizer_loading_epilogue(optimizer.optim)
@@ -752,7 +751,7 @@ def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_
                     state, current_shape=working_param.shape, original_shape=original_shape, device=device, dtype=dtype, inplace=True
                 )
             get_accelerator().synchronize()
-            optimizer.state.update(new_states)
+            optimizer.optim.state.update(new_states)
 
     def save_unsharded_model(
         self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
@@ -998,7 +997,7 @@ def _get_param_id_from_optimizer_param(
             for param in pg["params"]:
                 param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 id_map[param_id] = param
-        self.load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map)
+        self.load_states_into_optimizer(optimizer, state_dict["state"], id_map)
 
         sharded_optimizer_loading_epilogue(optimizer.optim)
 

From 797bd8b8c5c2720703b44d6d213c62a983246b2b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 11 Feb 2025 06:35:37 +0000
Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../hybrid_parallel_checkpoint_io.py          | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 4a6dae7392a0..b4927b437085 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -1,21 +1,21 @@
 import copy
 import logging
 import os
+from collections import defaultdict
 from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
-from collections import defaultdict
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
+from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
-from colossalai.accelerator import get_accelerator
-from torch.optim import Optimizer
 
+from colossalai.accelerator import get_accelerator
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.padded_tensor import (
@@ -748,8 +748,13 @@ def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_
                     working_param = param
                 original_shape = optimizer.param_info["param2shape"][id(working_param)]
                 new_states[param] = self.shard_from_complete_optimizer_state(
-                    state, current_shape=working_param.shape, original_shape=original_shape, device=device, dtype=dtype, inplace=True
-                )
+                    state,
+                    current_shape=working_param.shape,
+                    original_shape=original_shape,
+                    device=device,
+                    dtype=dtype,
+                    inplace=True,
+                )
             get_accelerator().synchronize()
             optimizer.optim.state.update(new_states)
 

From 60803fc77c0acd7058304fdd020ffd969b3714b1 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 11 Feb 2025 15:09:33 +0800
Subject: [PATCH 4/5] Update hybrid_parallel_checkpoint_io.py

---
 colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index b4927b437085..e47c0ed3ebed 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -755,8 +755,8 @@ def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_
                     dtype=dtype,
                     inplace=True,
                 )
-            get_accelerator().synchronize()
-            optimizer.optim.state.update(new_states)
+        get_accelerator().synchronize()
+        optimizer.optim.state.update(new_states)
 
     def save_unsharded_model(
         self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False

From dd01d46d2602de2fcada87d18cb4cc5aea1bc226 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 12 Feb 2025 10:13:37 +0800
Subject: [PATCH 5/5] fix

---
 colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index e47c0ed3ebed..bd814f426b68 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -15,7 +15,6 @@
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
 
-from colossalai.accelerator import get_accelerator
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.padded_tensor import (
@@ -755,7 +754,6 @@ def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_
                     dtype=dtype,
                     inplace=True,
                 )
-        get_accelerator().synchronize()
         optimizer.optim.state.update(new_states)
 
     def save_unsharded_model(
@@ -1135,7 +1133,7 @@ def shard_from_complete_optimizer_state(
                         slice_size = v.numel() // self.global_dp_size
                         v = v.split(slice_size, dim=0)[self.dp_rank]
 
-                state_[k] = v.detach().clone().to(device)
+                state_[k] = v.detach().clone().to(device=device, dtype=dtype)
 
         return state_
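
Note (illustrative, not part of the patch series above): after PATCH 5/5, the loader reads the full, unsharded optimizer state for each parameter and immediately cuts it down to the local shard through the new load_states_into_optimizer / shard_from_complete_optimizer_state path, casting to the master parameter's dtype and device in a single .to(device=..., dtype=...) call. The snippet below is a minimal, self-contained sketch of that shard-on-load idea for a plain torch.optim.Adam with one split along dim 0; the helper name shard_on_load and the rank/world_size arguments are hypothetical stand-ins, not ColossalAI APIs.

    import torch
    from torch.optim import Adam

    def shard_on_load(full_state: dict, rank: int, world_size: int, device, dtype) -> dict:
        # Keep scalar entries (e.g. "step") as-is; slice every tensor state along dim 0
        # and move the local slice to the target device/dtype, mirroring the patched
        # v.detach().clone().to(device=device, dtype=dtype) behaviour.
        sharded = {}
        for name, value in full_state.items():
            if isinstance(value, torch.Tensor) and name != "step" and value.dim() > 0:
                slice_size = value.shape[0] // world_size
                value = value.split(slice_size, dim=0)[rank]
                value = value.detach().clone().to(device=device, dtype=dtype)
            sharded[name] = value
        return sharded

    # Toy usage: the checkpoint holds full 8x4 Adam states, but this rank only owns
    # rows 4..7 of the weight (rank 1 of 2).
    full_weight = torch.randn(8, 4)
    local_weight = torch.nn.Parameter(full_weight[4:].clone())
    optimizer = Adam([local_weight])

    full_ckpt_state = {
        "step": torch.tensor(10.0),
        "exp_avg": torch.randn(8, 4),
        "exp_avg_sq": torch.rand(8, 4),
    }
    optimizer.state[local_weight] = shard_on_load(
        full_ckpt_state, rank=1, world_size=2, device=local_weight.device, dtype=local_weight.dtype
    )
    assert optimizer.state[local_weight]["exp_avg"].shape == local_weight.shape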