From c5e04f5f3b29798a40ff45ec7586d91a5c8d94d6 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 6 Dec 2024 16:56:19 +0800 Subject: [PATCH] fix async io --- colossalai/checkpoint_io/general_checkpoint_io.py | 3 +-- colossalai/checkpoint_io/utils.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 2545806775a4..851e41b4cac5 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -8,8 +8,6 @@ import torch.nn as nn from torch.optim import Optimizer -from colossalai.utils.safetensors import move_and_save - from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( @@ -55,6 +53,7 @@ def save_unsharded_model( if use_async: from tensornvme.async_file_io import AsyncFileWriter + from colossalai.utils.safetensors import move_and_save writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread") if id(model) not in self.pinned_state_dicts: diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 77b9faa0bcbf..6dc3fe6ea7e7 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -19,7 +19,6 @@ to_global, to_global_for_customized_distributed_tensor, ) -from colossalai.utils.safetensors import move_and_save SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -290,6 +289,7 @@ def async_save_state_dict_shards( Returns: int: the total size of shards """ + from colossalai.utils.safetensors import move_and_save from tensornvme.async_file_io import AsyncFileWriter total_size = 0