From 08e3616282b02dee980dffc5e9e5c1cd9f793dbf Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 15 Nov 2024 14:11:07 +0800 Subject: [PATCH 01/10] support async optimizer save/load --- colossalai/booster/booster.py | 5 +- .../booster/plugin/low_level_zero_plugin.py | 49 ++++++++- .../checkpoint_io/checkpoint_io_base.py | 26 +++-- colossalai/checkpoint_io/utils.py | 17 ++- colossalai/testing/comparison.py | 2 +- colossalai/utils/safetensors.py | 104 +++++++++++++++--- .../test_low_level_zero_checkpoint_io.py | 11 +- .../test_safetensors_async_io.py | 52 +++++++++ 8 files changed, 232 insertions(+), 34 deletions(-) create mode 100644 tests/test_checkpoint_io/test_safetensors_async_io.py diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 2518b25111a4..ba569a49c5ed 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -358,6 +358,7 @@ def save_optimizer( gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024, + use_async: bool = False, ) -> None: """ Save optimizer to checkpoint. @@ -373,7 +374,9 @@ def save_optimizer( names to compose the keys in state_dict. Defaults to None. size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. """ - self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard) + self.checkpoint_io.save_optimizer( + optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard, use_async=use_async + ) def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None: """Save lr scheduler to checkpoint. diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index d4eb1bbed75a..9e333ffe73e4 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -24,6 +24,7 @@ get_shard_filename, load_param_groups_into_optimizer, load_shard_state_dict, + load_state_dict, load_states_into_optimizer, save_param_groups, save_state_dict, @@ -113,7 +114,9 @@ def _hook_context(self): class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): + def save_unsharded_optimizer( + self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False + ): """Save optimizer to checkpoint but only on master process. 
Args: @@ -127,7 +130,26 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, # the communication on each rank would not match state_dict = optimizer.state_dict() if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors=False) + if use_async: + from tensornvme.async_file_io import AsyncFileWriter + + from colossalai.utils.safetensors import save_nested + + f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread") + save_nested(f_writer, state_dict["state"], state_dict["param_groups"]) + self.async_writers.append(f_writer) + else: + save_state_dict(state_dict, checkpoint, use_safetensors=False) + + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + use_async = checkpoint.endswith(".safetensors") + if use_async: + from colossalai.utils.safetensors import load_flat + + checkpoint = load_flat(checkpoint) + else: + checkpoint = load_state_dict(checkpoint) + optimizer.load_state_dict(checkpoint) def save_sharded_optimizer( self, @@ -136,6 +158,7 @@ def save_sharded_optimizer( gather_dtensor: bool = False, prefix: str = None, size_per_shard: int = 1024, + use_async: bool = False, ): """ Save sharded Zero-optimizer checkpoint under the given checkpointing path. @@ -164,7 +187,7 @@ def save_sharded_optimizer( sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard) # Preparing file paths and index file. - states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async) index_file = CheckpointIndexFile(checkpoint) index_file.append_meta_data("param_groups", param_group_file) @@ -184,7 +207,18 @@ def save_sharded_optimizer( checkpoint_file_path = os.path.join(checkpoint, shard_file) if self.coordinator.is_master(): - save_state_dict(shard, checkpoint_file_path, use_safetensors=False) + if use_async: + from tensornvme.async_file_io import AsyncFileWriter + + from colossalai.utils.safetensors import save_nested + + f_writer = AsyncFileWriter( + fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread" + ) + save_nested(f_writer, shard) + self.async_writers.append(f_writer) + else: + save_state_dict(shard, checkpoint_file_path, use_safetensors=False) # Wrap up index file. 
index_file.append_meta_data("total_size", total_size) @@ -223,7 +257,12 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() for shard_file in checkpoint_files: - state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) + if shard_file.endswith(".safetensors"): + from colossalai.utils.safetensors import load_flat + + state_dict = load_flat(shard_file) + else: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) # shard state dict for param_idx, state in state_dict.items(): for k, v in state.items(): diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 4d752f3e6e9c..070f42ad21ee 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -10,7 +10,7 @@ from colossalai.interface import ModelWrapper from colossalai.logging import get_dist_logger -from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file +from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, get_optimizer_state_dict_numl, has_index_file __all__ = ["CheckpointIO"] @@ -213,6 +213,7 @@ def save_optimizer( gather_dtensor=True, prefix: str = None, size_per_shard: int = 1024, + use_async: bool = False, ): """ Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors. @@ -229,11 +230,14 @@ def save_optimizer( prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True. """ - - if shard: - self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) + if not shard and use_async: + size_per_shard = get_optimizer_state_dict_numl(optimizer) + if shard or use_async: + self.save_sharded_optimizer( + optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async + ) else: - self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) + self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async) # ======================================================== # Abstract methods for model loading/saving implementation @@ -326,7 +330,13 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): @abstractmethod def save_sharded_optimizer( - self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + self, + optimizer: Optimizer, + checkpoint: Path, + gather_dtensor: bool, + prefix: str, + size_per_shard: int, + use_async: bool = False, ): """ Save optimizer to sharded checkpoint. @@ -340,7 +350,9 @@ def save_sharded_optimizer( """ @abstractmethod - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): + def save_unsharded_optimizer( + self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False + ): """ Save optimizer to unsharded checkpoint. 
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 6d539cce60c9..629787e68d0e 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -24,9 +24,11 @@ SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" STATES_NAME = "pytorch_optim.bin" +SAFE_STATE_NAME = "optimizer.safetensors" SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" STATES_INDEX_NAME = "pytorch_optim.bin.index.json" +SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json" GROUP_FILE_NAME = "pytorch_optim_group.bin" # ====================================== @@ -842,14 +844,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False): return weights_name, save_index_file -def get_optimizer_base_filenames(prefix: str = None): +def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False): """ generate base optimizer state filenames """ - states_name = STATES_NAME + states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME states_name = add_prefix(states_name, prefix) - save_index_file = STATES_INDEX_NAME + save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME save_index_file = add_prefix(save_index_file, prefix) param_group_file = GROUP_FILE_NAME @@ -872,3 +874,12 @@ def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]): for name, tensor in state_dict.items(): pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu") return pin_mem + + +def get_optimizer_state_dict_numl(optimizer): + total_size = 0 + state_dict = optimizer.state_dict() + for param_group in state_dict["state"].values(): + for param_name, param_tensor in param_group.items(): + total_size += torch.tensor(param_tensor).numel() if param_name == "step" else param_tensor.numel() + return total_size diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index 8f9cce246556..d41d346b8217 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -78,7 +78,7 @@ def check_state_dict_equal( v1 = v1.to(v2.dtype) assert_close_loose(v1, v2) else: - assert v1 == v2, f"{v1} not equals to {v2}" + assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}" def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True): diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py index bf8decd0faff..ad7d3be77d72 100644 --- a/colossalai/utils/safetensors.py +++ b/colossalai/utils/safetensors.py @@ -1,10 +1,11 @@ # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214 import json +import warnings from dataclasses import asdict, dataclass from typing import Dict, List, Optional, Tuple import torch -from safetensors.torch import _TYPES +from safetensors.torch import _TYPES, load_file, safe_open try: from tensornvme.async_file_io import AsyncFileWriter @@ -27,36 +28,93 @@ class PreparedData: offset: int -def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]: - sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0])) +def flatten_dict(nested_dict, parent_key="", separator="^"): + """ + Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator. + + nested_dict: The input nested dictionary. 
+    parent_key: The parent key currently being processed.
+    separator: The separator used to join keys; defaults to '^' but can be customized to another symbol.
+    :return: A flattened dictionary.
+    """
+    items = []
+    for k, v in nested_dict.items():
+        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, separator).items())
+        else:
+            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
+            items.append((new_key, v))
+
+    return dict(items)
+
+
+def unflatten_dict(flattened_dict, separator="^"):
+    """
+    Restore a flattened dictionary back to a multi-level nested dictionary.
+
+    flattened_dict: The flattened dictionary.
+    separator: The separator used during flattening; defaults to '^' but can be customized to another symbol.
+    :return: The restored nested dictionary.
+    """
+    nested_dict = {}
+    for key, value in flattened_dict.items():
+        keys = key.split(separator)
+        try:
+            keys[0] = int(keys[0])
+        except ValueError:
+            warnings.warn(f"{keys[0]} can't be converted to an integer")
+        d = nested_dict
+        for part in keys[:-1]:
+            if part not in d:
+                d[part] = {}
+            d = d[part]
+        assert isinstance(value, torch.Tensor)
+        d[keys[-1]] = value
+
+    return nested_dict
+
+
+def prepare(
+    data: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
+    if metadata is not None:
+        assert isinstance(metadata, dict)
+        for k, v in metadata.items():
+            metadata[k] = json.dumps(v)
+            assert isinstance(k, str)
+            assert isinstance(metadata[k], str)
 
     tensors = []
     tensor_keys = []
-    metadata = {}
+    header = {}
     offset = 0
 
-    for name, tensor in sorted_data:
+    if metadata is not None:
+        header["__metadata__"] = metadata
+
+    for name, tensor in data.items():
         n = tensor.numel() * tensor.element_size()
         tensor_info = TensorInfo(
             dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
         )
         offset += n
-        metadata[name] = asdict(tensor_info)
+        header[name] = asdict(tensor_info)
         tensors.append(tensor)
         tensor_keys.append(name)
 
-    metadata_buf = json.dumps(metadata).encode("utf-8")
+    header_buf = json.dumps(header).encode("utf-8")
 
-    extra = (8 - len(metadata_buf) % 8) % 8
-    metadata_buf += b" " * extra
+    extra = (8 - len(header_buf) % 8) % 8
+    header_buf += b" " * extra
 
-    n = len(metadata_buf)
+    n = len(header_buf)
 
-    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors, tensor_keys
+    return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys
 
 
-def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
-    prepared_data, tensors, _ = prepare(state_dict)
+def save(
+    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> None:
+    prepared_data, tensors, _ = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
 
     f_writer.write(n.to_bytes(8, byteorder="little"))
@@ -66,6 +124,13 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
 
 
+def save_nested(
+    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> None:
+    flatten_data = flatten_dict(state_dict)
+    save(f_writer, flatten_data, metadata)
+
+
 def move_and_save(
     f_writer: AsyncFileWriter,
     state_dict: Dict[str, torch.Tensor],
@@ -83,3 +148,16 @@ def move_and_save( f_writer.write_tensor(state_dict[name], state_dict_pinned[name]) else: f_writer.write_tensor(state_dict[name]) + + +def load_flat(checkpoint_path): + with safe_open(checkpoint_path, framework="pt") as f: + metadata = f.metadata() + state_dict_load = load_file(checkpoint_path) + state_dict = unflatten_dict(state_dict_load) + if metadata is None: + return state_dict + metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items())) + combined_state_dict = {"state": state_dict} + combined_state_dict.update(metadata) + return combined_state_dict diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index a8e05a25ad28..ede7a583829c 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -26,9 +26,10 @@ # only test 2 is fine @clear_cache_before_run() @parameterize("stage", [2]) -@parameterize("shard", [True, False]) +@parameterize("shard", [False, True]) @parameterize("offload", [False, True]) -def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): +@parameterize("use_async", [False, True]) +def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool): plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload) booster = Booster(plugin=plugin) model = resnet18() @@ -45,8 +46,10 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here + if not shard and use_async: + optimizer_ckpt_path = f"{tempdir}/optimizer.safetensors" booster.save_model(model, model_ckpt_path, shard=shard) - booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async) dist.barrier() @@ -124,7 +127,7 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo assert torch.equal( working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device) ) - + booster.checkpoint_io.synchronize() new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict()) diff --git a/tests/test_checkpoint_io/test_safetensors_async_io.py b/tests/test_checkpoint_io/test_safetensors_async_io.py new file mode 100644 index 000000000000..79435e26e1d4 --- /dev/null +++ b/tests/test_checkpoint_io/test_safetensors_async_io.py @@ -0,0 +1,52 @@ +import tempfile +from copy import deepcopy + +import torch + +from colossalai.utils.safetensors import load_flat, save_nested + +try: + from tensornvme.async_file_io import AsyncFileWriter +except ModuleNotFoundError: + raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer") + +from colossalai.testing import check_state_dict_equal + + +def test_save_load(): + with tempfile.TemporaryDirectory() as tempdir: + optimizer_state_dict = { + 0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))}, + 1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))}, + 2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), 
"exp_avg_sq": torch.rand((1024, 1024))}, + } + group_dict = {"param_groups": [0, 1, 2]} + metadata = deepcopy(group_dict) + optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors" + f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread") + + save_nested(f_writer, optimizer_state_dict, metadata) + f_writer.sync_before_step() + f_writer.synchronize() + f_writer.fp.close() + + load_state_dict = load_flat(optimizer_saved_path) + state_dict = load_state_dict["state"] + group = {"param_groups": load_state_dict["param_groups"]} + check_state_dict_equal(optimizer_state_dict, state_dict) + check_state_dict_equal(group_dict, group) + + model_state_dict = { + "module.weight0": torch.rand((1024, 1024)), + "module.weight1": torch.rand((1024, 1024)), + "module.weight2": torch.rand((1024, 1024)), + } + model_saved_path = f"{tempdir}/save_model.safetensors" + f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread") + save_nested(f_writer, model_state_dict) + f_writer.sync_before_step() + f_writer.synchronize() + f_writer.fp.close() + + load_state_dict = load_flat(model_saved_path) + check_state_dict_equal(model_state_dict, load_state_dict) From b193f1a243e3af4497d3a410c7529bb008503675 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 15 Nov 2024 18:29:26 +0800 Subject: [PATCH 02/10] fix --- colossalai/checkpoint_io/checkpoint_io_base.py | 6 ++---- colossalai/checkpoint_io/utils.py | 9 --------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 1e92abb96703..9e431d3559ac 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -10,7 +10,7 @@ from colossalai.interface import ModelWrapper from colossalai.logging import get_dist_logger -from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, get_optimizer_state_dict_numl, has_index_file +from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file __all__ = ["CheckpointIO"] @@ -230,9 +230,7 @@ def save_optimizer( prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True. 
""" - if not shard and use_async: - size_per_shard = get_optimizer_state_dict_numl(optimizer) - if shard or use_async: + if shard: self.save_sharded_optimizer( optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async ) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index dc4b123688f4..b8c15c374cf5 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -871,12 +871,3 @@ def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]): for name, tensor in state_dict.items(): pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu") return pin_mem - - -def get_optimizer_state_dict_numl(optimizer): - total_size = 0 - state_dict = optimizer.state_dict() - for param_group in state_dict["state"].values(): - for param_name, param_tensor in param_group.items(): - total_size += torch.tensor(param_tensor).numel() if param_name == "step" else param_tensor.numel() - return total_size From 1e3c3968c2aa620d9d52f687e684966238c53a44 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 15 Nov 2024 18:41:17 +0800 Subject: [PATCH 03/10] fix --- colossalai/booster/plugin/low_level_zero_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 9e333ffe73e4..413539ec0ab8 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -136,7 +136,7 @@ def save_unsharded_optimizer( from colossalai.utils.safetensors import save_nested f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread") - save_nested(f_writer, state_dict["state"], state_dict["param_groups"]) + save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]}) self.async_writers.append(f_writer) else: save_state_dict(state_dict, checkpoint, use_safetensors=False) From fcfecd705e7e082314b7b6beb2684eaf00bec0d7 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 15 Nov 2024 19:44:18 +0800 Subject: [PATCH 04/10] support pin mem --- .../booster/plugin/low_level_zero_plugin.py | 14 +++++++-- colossalai/zero/low_level/low_level_optim.py | 31 ++++++++++++++----- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 413539ec0ab8..3321fd701000 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -128,7 +128,12 @@ def save_unsharded_optimizer( # the `state_dict` in LowLevelZeroOptimizer has communication # if only the master rank collect state_dict and save, # the communication on each rank would not match - state_dict = optimizer.state_dict() + if use_async and id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] + else: + pinned_state_dicts = None + state_dict = optimizer.state_dict(pinned_state_dicts) if self.coordinator.is_master(): if use_async: from tensornvme.async_file_io import AsyncFileWriter @@ -184,7 +189,12 @@ def save_sharded_optimizer( # state_dict only provide only 'param_groups' state_dict = optimizer.optim.state_dict() # state shard would be handled by the low-level zero optimizer - sharded_state = 
optimizer.state_dict_shard(max_shard_size=size_per_shard)
+        if use_async:
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dict=pinned_state_dicts)
 
         # Preparing file paths and index file.
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 26fff75fbfdf..8bc024f8c60e 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -770,7 +770,7 @@ def pack_group(group):
 
         return {"state": packed_state, "param_groups": param_groups}
 
-    def state_dict(self) -> Dict:
+    def state_dict(self, pinned_state_dicts=None) -> Dict:
         """Return a state_dict same with DDP
 
         Returns:
@@ -779,15 +779,23 @@ def state_dict(self) -> Dict:
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
+            if pinned_state_dicts and param not in pinned_state_dicts:
+                pinned_state_dicts[param] = {}
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
                     working_param = self.master_to_working_param[id(param)]
+                    if pinned_state_dicts and k not in pinned_state_dicts[param]:
+                        pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu")
                     pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
-                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
-                    zero_state[param][k] = param_state
+                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts:
+                        pinned_state_dicts[param][k].copy_(param_state)
+                        zero_state[param][k] = pinned_state_dicts[param][k]
+                    else:
+                        zero_state[param][k] = param_state.cpu()
 
         states_dict = self._pack_state(zero_state)
 
@@ -822,7 +830,7 @@ def load_state_dict(self, state_dict: Dict):
 
         self.optim.load_state_dict(zero_state_dict)
 
-    def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
+    def state_dict_shard(self, max_shard_size: int = 1024, pinned_state_dicts=None) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of
         dictionary shard is specified by ``max_shard_size``. Only include the 'state' in state_dict.
@@ -847,18 +855,27 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i for param_idx, states in local_states.items(): current_block_size = 0 current_block = copy.deepcopy(states) - + if pinned_state_dicts and param_idx not in pinned_state_dicts: + pinned_state_dicts[param_idx] = {} master_param = idx2master[param_idx] working_param = self.master_to_working_param[id(master_param)] pg = self.param_to_pg[working_param] for k, v in states.items(): if isinstance(v, torch.Tensor) and k != "step": + if pinned_state_dicts and k not in pinned_state_dicts[param_idx]: + pinned_state_dicts[param_idx][k] = torch.empty_like( + working_param, pin_memory=True, device="cpu" + ) state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype) all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg) - state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu() + state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param) + if pinned_state_dicts: + pinned_state_dicts[param_idx][k].copy_(state_tensor) + current_block[k] = pinned_state_dicts[param_idx][k] + else: + current_block[k] = state_tensor.cpu() current_block_size += state_tensor.numel() - current_block[k] = state_tensor if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0: yield ret_block, ret_block_size From f67c59b30bfa623fad3c526b4fcf1d64e33c769d Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 15 Nov 2024 20:36:53 +0800 Subject: [PATCH 05/10] Update low_level_zero_plugin.py --- colossalai/booster/plugin/low_level_zero_plugin.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 3321fd701000..4b94a8715e82 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -128,8 +128,9 @@ def save_unsharded_optimizer( # the `state_dict` in LowLevelZeroOptimizer has communication # if only the master rank collect state_dict and save, # the communication on each rank would not match - if use_async and id(optimizer) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(optimizer)] = {} + if use_async: + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = {} pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] else: pinned_state_dicts = None @@ -189,8 +190,9 @@ def save_sharded_optimizer( # state_dict only provide only 'param_groups' state_dict = optimizer.optim.state_dict() # state shard would be handled by the low-level zero optimizer - if use_async and id(optimizer) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(optimizer)] = {} + if use_async: + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = {} pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] else: pinned_state_dicts = None From 511a66441d8c0b8653d28bce30857116d9404017 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 18 Nov 2024 14:11:07 +0800 Subject: [PATCH 06/10] fix --- colossalai/booster/plugin/gemini_plugin.py | 12 ++- .../booster/plugin/low_level_zero_plugin.py | 2 +- colossalai/booster/plugin/torch_ddp_plugin.py | 9 ++- .../booster/plugin/torch_fsdp_plugin.py | 12 ++- .../checkpoint_io/general_checkpoint_io.py | 2 + .../hybrid_parallel_checkpoint_io.py | 5 +- 
colossalai/testing/comparison.py | 5 +- colossalai/utils/safetensors.py | 35 ++++++++- requirements/requirements-test.txt | 2 +- .../test_low_level_zero_checkpoint_io.py | 12 +-- .../test_safetensors_async_io.py | 77 ++++++++++++++++++- 11 files changed, 154 insertions(+), 19 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 35c51da0105a..30c1257ef14c 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -94,7 +94,9 @@ def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = assert isinstance(model, GeminiDDP), "Please boost the model before loading!" super().load_unsharded_model(model, checkpoint, strict=strict) - def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer( + self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False + ): """ Save unsharded optimizer state dict to checkpoint. After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank. @@ -178,7 +180,13 @@ def load_sharded_model( return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) def save_sharded_optimizer( - self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + self, + optimizer: GeminiOptimizer, + checkpoint: Path, + gather_dtensor: bool, + prefix: str, + size_per_shard: int, + use_async: bool = False, ): """ Save sharded optimizer state dict to checkpoint folder. diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 3321fd701000..b9aae5ad6e25 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -194,7 +194,7 @@ def save_sharded_optimizer( pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] else: pinned_state_dicts = None - sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dict=pinned_state_dicts) + sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts) # Preparing file paths and index file. states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async) diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 09830a2f9873..07be5b0516f6 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -52,7 +52,9 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str) assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" super().load_unsharded_optimizer(optimizer, checkpoint) - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer( + self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False + ): """ Save optimizer to checkpoint but only on master process. """ @@ -113,13 +115,16 @@ def save_sharded_optimizer( gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024, + use_async: bool = False, ): """ Save optimizer to sharded checkpoint but only on master process. 
""" assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" if self.coordinator.is_master(): - super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard) + super().save_sharded_optimizer( + optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async + ) def load_sharded_optimizer( self, diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index d309370dd620..b80d6d4b6eb8 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -67,7 +67,9 @@ def save_unsharded_model( full_model_state = model.state_dict() utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer( + self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False + ): """ Save optimizer to checkpoint but only on master process. """ @@ -157,7 +159,13 @@ def load_sharded_model( model.unwrap().load_state_dict(fsdp_state_dict, strict=False) def save_sharded_optimizer( - self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int + self, + optimizer: Optimizer, + checkpoint: str, + gather_dtensor: bool, + prefix: str, + size_per_shard: int, + use_async: bool = False, ): """ Save optimizer to checkpoint but only on master process. diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 580be91ca0d8..a2d1dd158afa 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -98,6 +98,7 @@ def save_sharded_optimizer( gather_dtensor: bool, prefix: str, size_per_shard: int, + use_async: bool = False, ): """ Save sharded optimizer checkpoint under the given checkpointing path. @@ -155,6 +156,7 @@ def save_unsharded_optimizer( optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, + use_async: bool = False, ): # TODO(FrankLeeeee): handle distributed tensors save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 49d4f35f9cc0..d66171c58ccd 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -416,6 +416,7 @@ def save_sharded_optimizer( gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024, + use_async: bool = False, ): """ Save sharded optimizer checkpoint under the given checkpointing path. @@ -725,7 +726,9 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo # Update master params if mixed-precision training is enabled. model_before_wrapping.update_master_params() - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer( + self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False + ): """ Save optimizer state dict to a file with given path. 
diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py
index d41d346b8217..1ee8a492dd8a 100644
--- a/colossalai/testing/comparison.py
+++ b/colossalai/testing/comparison.py
@@ -1,4 +1,4 @@
-from typing import Any, List, OrderedDict
+from typing import Any, List, OrderedDict, Tuple
 
 import torch
 import torch.distributed as dist
@@ -78,6 +78,9 @@ def check_state_dict_equal(
                     v1 = v1.to(v2.dtype)
                 assert_close_loose(v1, v2)
             else:
+                if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
+                    v2 = tuple(v2)
+                print("key", k)
                 assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
 
 
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index ad7d3be77d72..25bcf387b8b4 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -28,6 +28,32 @@ class PreparedData:
     offset: int
 
 
+# class TupleEncoder(json.JSONEncoder):
+#     def default(self, obj):
+#         if isinstance(obj, tuple):
+#             return {"__tuple__": True, "items": list(obj)}
+#         return super().default(obj)
+
+
+# def tuple_decoder(d):
+#     if "__tuple__" in d:
+#         return tuple(d["items"])
+#     return d
+# Custom JSON encoder for handling tuples
+class NestedTupleEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, tuple):
+            return {"__tuple__": True, "items": list(obj)}
+        return super().default(obj)
+
+
+# Custom decoder for handling tuples
+def nested_tuple_decoder(d):
+    if "__tuple__" in d:
+        return tuple(d["items"])
+    return d
+
+
 def flatten_dict(nested_dict, parent_key="", separator="^"):
     """
     Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
@@ -79,7 +105,8 @@ def prepare(
     if metadata is not None:
         assert isinstance(metadata, dict)
         for k, v in metadata.items():
-            metadata[k] = json.dumps(v)
+            metadata[k] = json.dumps(v, cls=NestedTupleEncoder)
+            print("metadata[k]", type(metadata[k]))
             assert isinstance(k, str)
             assert isinstance(metadata[k], str)
 
@@ -157,7 +184,11 @@ def load_flat(checkpoint_path):
     state_dict = unflatten_dict(state_dict_load)
     if metadata is None:
         return state_dict
-    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
+    print("metadata", metadata)
+    metadata = dict(
+        map(lambda item: (item[0], json.loads(item[1], object_hook=nested_tuple_decoder)), metadata.items())
+    )
+    # metadata = json.loads(metadata, object_hook=tuple_decoder)
     combined_state_dict = {"state": state_dict}
     combined_state_dict.update(metadata)
     return combined_state_dict
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 3fcf53e1858e..0d4c26db5b17 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -13,7 +13,7 @@ triton
 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
 SentencePiece
 ninja
-flash_attn
+flash_attn==2.5.0
 datasets
 pydantic
 ray
diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
index 5e10bf3e2a1c..4dc8e9f54fdd 100644
--- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
@@ -51,17 +51,15 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
             model_ckpt_path = f"{model_ckpt_path}.pt"
         if not shard and use_async:
             model_ckpt_path = f"{model_ckpt_path}.safetensors"
+        # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we
can skip it here + if not shard and use_async: + optimizer_ckpt_path = f"{tempdir}/optimizer.safetensors" booster.save_model( model, model_ckpt_path, shard=shard, use_async=use_async, ) - - # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here - if not shard and use_async: - optimizer_ckpt_path = f"{tempdir}/optimizer.safetensors" - booster.save_model(model, model_ckpt_path, shard=shard) booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async) booster.checkpoint_io._sync_d2h() @@ -88,7 +86,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us ) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict()) + print(optimizer.optim.state_dict()) + check_state_dict_equal(optimizer.optim.state_dict()["state"], new_optimizer.optim.state_dict()["state"]) torch.cuda.empty_cache() @@ -144,6 +143,7 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device) ) new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + # print(optimizer.optim.state_dict()["param_groups"], new_optimizer.optim.state_dict()["param_groups"]) check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict()) except Exception as e: diff --git a/tests/test_checkpoint_io/test_safetensors_async_io.py b/tests/test_checkpoint_io/test_safetensors_async_io.py index 79435e26e1d4..31c69e961e30 100644 --- a/tests/test_checkpoint_io/test_safetensors_async_io.py +++ b/tests/test_checkpoint_io/test_safetensors_async_io.py @@ -20,7 +20,82 @@ def test_save_load(): 1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))}, 2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))}, } - group_dict = {"param_groups": [0, 1, 2]} + # group_dict = {"param_groups": [0, 1, 2]} + group_dict = { + "param_groups": [ + { + "lr": 0.001, + "betas": (0.9, 0.999), + "eps": 1e-08, + "weight_decay": 0, + "bias_correction": True, + "params": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + ], + } + ] + } metadata = deepcopy(group_dict) optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors" f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread") From 659b9a3e3093ac54861fe1ad77f485609698eaa4 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 18 Nov 2024 14:17:02 +0800 Subject: [PATCH 07/10] fix --- colossalai/testing/comparison.py | 1 - colossalai/utils/safetensors.py | 35 ++----------------- .../test_low_level_zero_checkpoint_io.py | 2 -- 3 files changed, 2 insertions(+), 36 deletions(-) diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index 1ee8a492dd8a..4cbb01163e5a 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -80,7 +80,6 @@ def check_state_dict_equal( else: if isinstance(v1, Tuple) and not isinstance(v2, Tuple): v2 = tuple(v2) - print("key", k) 
assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}" diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py index 25bcf387b8b4..ad7d3be77d72 100644 --- a/colossalai/utils/safetensors.py +++ b/colossalai/utils/safetensors.py @@ -28,32 +28,6 @@ class PreparedData: offset: int -# class TupleEncoder(json.JSONEncoder): -# def default(self, obj): -# if isinstance(obj, tuple): -# return {"__tuple__": True, "items": list(obj)} -# return super().default(obj) - - -# def tuple_decoder(d): -# if "__tuple__" in d: -# return tuple(d["items"]) -# return d -# 自定义 JSON 编码器,处理 tuple -class NestedTupleEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, tuple): - return {"__tuple__": True, "items": list(obj)} - return super().default(obj) - - -# 自定义解码器,处理 tuple -def nested_tuple_decoder(d): - if "__tuple__" in d: - return tuple(d["items"]) - return d - - def flatten_dict(nested_dict, parent_key="", separator="^"): """ Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator. @@ -105,8 +79,7 @@ def prepare( if metadata is not None: assert isinstance(metadata, dict) for k, v in metadata.items(): - metadata[k] = json.dumps(v, cls=NestedTupleEncoder) - print("metadata[k]", type(metadata[k])) + metadata[k] = json.dumps(v) assert isinstance(k, str) assert isinstance(metadata[k], str) @@ -184,11 +157,7 @@ def load_flat(checkpoint_path): state_dict = unflatten_dict(state_dict_load) if metadata is None: return state_dict - print("metadata", metadata) - metadata = dict( - map(lambda item: (item[0], json.loads(item[1], object_hook=nested_tuple_decoder)), metadata.items()) - ) - # metadata = json.loads(metadata, object_hook=tuple_decoder) + metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items())) combined_state_dict = {"state": state_dict} combined_state_dict.update(metadata) return combined_state_dict diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 4dc8e9f54fdd..22e4c7bf0dab 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -86,7 +86,6 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us ) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - print(optimizer.optim.state_dict()) check_state_dict_equal(optimizer.optim.state_dict()["state"], new_optimizer.optim.state_dict()["state"]) torch.cuda.empty_cache() @@ -143,7 +142,6 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device) ) new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - # print(optimizer.optim.state_dict()["param_groups"], new_optimizer.optim.state_dict()["param_groups"]) check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict()) except Exception as e: From 4f2844161cbe236300bd5b45f93c2fbc6765817e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 18 Nov 2024 14:36:20 +0800 Subject: [PATCH 08/10] fix --- requirements/requirements-test.txt | 2 +- tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 0d4c26db5b17..3fcf53e1858e 100644 --- 
a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -13,7 +13,7 @@ triton requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 SentencePiece ninja -flash_attn==2.5.0 +flash_attn datasets pydantic ray diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 22e4c7bf0dab..e814d3ec68a3 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -86,7 +86,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us ) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.optim.state_dict()["state"], new_optimizer.optim.state_dict()["state"]) + check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict()) torch.cuda.empty_cache() From 55f4e43c910eeb0d72ebd3c47c4aafd3bc5b169d Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 18 Nov 2024 14:56:05 +0800 Subject: [PATCH 09/10] fix --- colossalai/zero/low_level/low_level_optim.py | 6 ++++-- .../test_checkpoint_io/test_low_level_zero_checkpoint_io.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 8bc024f8c60e..e3c301640867 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -770,7 +770,7 @@ def pack_group(group): return {"state": packed_state, "param_groups": param_groups} - def state_dict(self, pinned_state_dicts=None) -> Dict: + def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict: """Return a state_dict same with DDP Returns: @@ -830,7 +830,9 @@ def load_state_dict(self, state_dict: Dict): self.optim.load_state_dict(zero_state_dict) - def state_dict_shard(self, max_shard_size: int = 1024, pinned_state_dicts=None) -> Iterator[Tuple[Dict, int]]: + def state_dict_shard( + self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None + ) -> Iterator[Tuple[Dict, int]]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. Only include the 'state' in state_dict. 
diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index e814d3ec68a3..05dfcce4f674 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -51,7 +51,6 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us model_ckpt_path = f"{model_ckpt_path}.pt" if not shard and use_async: model_ckpt_path = f"{model_ckpt_path}.safetensors" - # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here if not shard and use_async: optimizer_ckpt_path = f"{tempdir}/optimizer.safetensors" booster.save_model( @@ -60,8 +59,9 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us shard=shard, use_async=use_async, ) - booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async) + # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async) booster.checkpoint_io._sync_d2h() booster.checkpoint_io._sync_io() dist.barrier() From a4a053227d161eb1ff08e79d15b09efd1c385450 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 18 Nov 2024 16:18:01 +0800 Subject: [PATCH 10/10] fix --- colossalai/checkpoint_io/moe_checkpoint.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py index 4cb0f300f65e..3b07856ca06c 100644 --- a/colossalai/checkpoint_io/moe_checkpoint.py +++ b/colossalai/checkpoint_io/moe_checkpoint.py @@ -369,6 +369,7 @@ def save_sharded_optimizer( gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024, + use_async: bool = False, ): """ Save sharded optimizer checkpoint under the given checkpointing path. @@ -729,7 +730,13 @@ def save_unsharded_model( dist.barrier() # Copied from colossalai.moe - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool, + use_async: bool = False, + ): """ Save optimizer state dict to a file with given path.
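As a usage sketch of the API this series adds (assumptions: the code runs inside an already initialized ColossalAI distributed context, tensornvme is installed, and resnet18/Adam plus the checkpoint path are illustrative stand-ins rather than anything prescribed by the patches):

import torch
from torchvision.models import resnet18

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Same plugin configuration as the tests in this series.
plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
booster = Booster(plugin=plugin)

model = resnet18()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # stand-in optimizer
model, optimizer, *_ = booster.boost(model, optimizer)

# Unsharded async save writes one flattened safetensors file: save_nested()
# flattens the nested optimizer state dict with '^'-joined keys, and the
# param_groups travel in the safetensors metadata.
booster.save_optimizer(optimizer, "ckpt/optimizer.safetensors", shard=False, use_async=True)

# Writes are queued on an AsyncFileWriter; flush device-to-host copies and
# file IO before reading the checkpoint back, as the updated tests do.
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()

# Loading dispatches on the ".safetensors" suffix and rebuilds the nested
# state dict via load_flat()/unflatten_dict().
booster.load_optimizer(optimizer, "ckpt/optimizer.safetensors")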