Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
create_pinned_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
load_state_dict_shards,
save_config_file,
save_state_dict,
save_state_dict_shards,
Expand All @@ -29,7 +29,6 @@
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils.safetensors import load_flat
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats

Expand Down Expand Up @@ -350,11 +349,9 @@ def load_sharded_optimizer(

# Load optimizer states from shard files under checkpoint path.
# For each file, only load the states managed by current process.
for shard_file in checkpoint_files:
if shard_file.endswith(".safetensors"):
state_dict_shard = load_flat(shard_file)
else:
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
for state_dict_shard in load_state_dict_shards(
checkpoint_files, True, False, low_cpu_mem_mode=low_cpu_mem_mode
):
if not low_cpu_mem_mode:
state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
optimizer.load_param_states(state_dict_shard)
Expand Down
10 changes: 2 additions & 8 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
get_optimizer_base_filenames,
get_shard_filename,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_state_dict_shards,
load_states_into_optimizer,
save_param_groups,
save_state_dict,
Expand Down Expand Up @@ -276,13 +276,7 @@ def load_sharded_optimizer(

checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

for shard_file in checkpoint_files:
if shard_file.endswith(".safetensors"):
from colossalai.utils.safetensors import load_flat

state_dict = load_flat(shard_file)
else:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
# shard state dict
for param_idx, state in state_dict.items():
for k, v in state.items():
Expand Down
10 changes: 3 additions & 7 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,8 @@ def load_sharded_model(
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

fsdp_state_dict = {}
for shard_file in checkpoint_files:
fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
for state_dict in utils.load_state_dict_shards(checkpoint_files, False, use_safetensors):
fsdp_state_dict.update(state_dict)

with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
Expand Down Expand Up @@ -388,11 +388,7 @@ def load_sharded_optimizer(
# Load param
fsdp_optim_state = {}
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
if shard_file.endswith(".safetensors"):
state_dict_shard = load_flat(shard_file, seperator=".")
else:
state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
for state_dict_shard in utils.load_state_dict_shards(checkpoint_files, True, False):
fsdp_optim_state.update(state_dict_shard)

fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
Expand Down
11 changes: 3 additions & 8 deletions colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
get_optimizer_base_filenames,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_state_dict_shards,
load_states_into_optimizer,
save_config_file,
save_param_groups,
Expand Down Expand Up @@ -94,11 +94,7 @@ def load_sharded_optimizer(

checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

for shard_file in checkpoint_files:
if shard_file.endswith(".safetensors"):
state_dict = load_flat(shard_file)
else:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_states_into_optimizer(optimizer, state_dict, id_map)
Expand Down Expand Up @@ -295,8 +291,7 @@ def load_sharded_model(
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
missing_keys = []

for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
for state_dict in load_state_dict_shards(checkpoint_files, False, use_safetensors, low_cpu_mem_mode):
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
Expand Down
36 changes: 34 additions & 2 deletions colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union

import torch
import torch.nn as nn
Expand All @@ -21,7 +21,7 @@
to_global,
to_global_for_customized_distributed_tensor,
)
from colossalai.utils.safetensors import _flatten_optim_state_dict
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat

SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
Expand Down Expand Up @@ -972,3 +972,35 @@ def create_pinned_state_dict(
idx = future_to_idx[future]
elems[idx] = future.result()
return tree_unflatten(elems, spec)


def load_optim_or_model_shard(path: str, is_optim: bool, use_safetensors: bool) -> dict:
    """Load one checkpoint shard file and return its state dict.

    Args:
        path: Location of the shard file.
        is_optim: ``True`` when the shard holds optimizer states, ``False`` for
            model weights.
        use_safetensors: Forwarded to ``load_shard_state_dict`` for model shards.
            It is not consulted for optimizer shards, whose loader is picked
            from the file suffix instead.

    Returns:
        The state dict read from ``path``.
    """
    if not is_optim:
        return load_shard_state_dict(Path(path), use_safetensors)

    # Optimizer shards: ".safetensors" files go through ``load_flat``; any other
    # suffix uses the regular shard loader with safetensors disabled.
    if path.endswith(".safetensors"):
        return load_flat(path)
    return load_shard_state_dict(Path(path), use_safetensors=False)


def load_state_dict_shards(
    checkpoint_files: List[str],
    is_optim: bool,
    use_safetensors: bool,
    low_cpu_mem_mode: bool = True,
    prefetch: int = 3,
) -> Generator[dict, None, None]:
    """Lazily yield the state dict of every shard in ``checkpoint_files``.

    Args:
        checkpoint_files: Paths of the shard files to load.
        is_optim: Whether the shards hold optimizer states (forwarded to
            ``load_optim_or_model_shard``).
        use_safetensors: Forwarded to ``load_optim_or_model_shard``.
        low_cpu_mem_mode: When ``True``, shards are loaded one at a time in the
            calling thread, each only when the consumer asks for it. When
            ``False``, shards are loaded ahead of the consumer by a thread pool.
        prefetch: Thread-pool size and maximum number of shards loaded ahead of
            the one currently being consumed. Values below 1 are treated as 1.
            Ignored when ``low_cpu_mem_mode`` is ``True``.

    Yields:
        One state dict per shard file, in the order of ``checkpoint_files``.

    Raises:
        Any exception raised while loading a shard is re-raised when that
        shard's turn to be yielded comes.
    """
    if low_cpu_mem_mode:
        for shard_file in checkpoint_files:
            yield load_optim_or_model_shard(shard_file, is_optim, use_safetensors)
        return

    # ThreadPoolExecutor rejects max_workers <= 0, so clamp instead of crashing.
    window = max(1, prefetch)
    remaining = iter(checkpoint_files)
    # Futures for shards submitted but not yet handed to the consumer, oldest
    # first. It never holds more than ``window`` entries, so memory use is
    # bounded by ``prefetch`` rather than by the total number of shards.
    in_flight = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=window) as executor:

        def submit_next() -> None:
            shard_file = next(remaining, None)
            if shard_file is not None:
                in_flight.append(
                    executor.submit(load_optim_or_model_shard, shard_file, is_optim, use_safetensors)
                )

        try:
            for _ in range(window):
                submit_next()
            while in_flight:
                # Drop the future from the list before yielding so that nothing
                # here keeps the shard alive once the consumer is done with it.
                future = in_flight.pop(0)
                # Refill the window first so loading overlaps with consumption.
                submit_next()
                yield future.result()
        finally:
            # On early exit or error, skip shards that have not started loading
            # instead of waiting for all of them in the executor shutdown.
            for future in in_flight:
                future.cancel()
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import gc

from colossalai.accelerator import get_accelerator


def pytest_runtest_setup(item):
    """Pytest setup hook: clear the accelerator cache, then run the garbage collector.

    Args:
        item: The test item about to be set up; not used here.
    """
    get_accelerator().empty_cache()
    gc.collect()