2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -208,7 +208,7 @@ jobs:

       - name: Execute Unit Testing
         run: |
-          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-cov=. --durations=10 tests/
+          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/
         env:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
4 changes: 2 additions & 2 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -141,10 +141,10 @@ def get_param_info(optim: Optimizer):


 def init_pipeline_optimizer(optim: Optimizer, model: Module):
-    params = set(model.parameters())
+    model_params = set(model.parameters())
     new_param_groups = []
     for group in optim.param_groups:
-        params = [p for p in group['params'] if p in params]
+        params = [p for p in group['params'] if p in model_params]
         new_param_groups.append({**group, 'params': params})
     optim.__setstate__({'param_groups': new_param_groups})
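
This rename fixes a shadowing bug: the loop rebound `params`, so every group after the first was filtered against the previous group's filtered list rather than the model's parameter set. A minimal standalone sketch of the pitfall, with hypothetical integers standing in for parameters:

```python
# Sketch of the shadowing bug fixed above, with integers standing in
# for model parameters (hypothetical values, not the real optimizer API).
model_params = {1, 2, 3, 4}                              # set(model.parameters())
param_groups = [{'params': [1, 2]}, {'params': [3, 4]}]

# Buggy version: `params` is rebound to the first group's filtered list,
# so the second group is filtered against [1, 2] and comes out empty.
params = model_params
buggy = []
for group in param_groups:
    params = [p for p in group['params'] if p in params]
    buggy.append(params)
assert buggy == [[1, 2], []]

# Fixed version keeps the model-parameter set under its own name.
fixed = []
for group in param_groups:
    fixed.append([p for p in group['params'] if p in model_params])
assert fixed == [[1, 2], [3, 4]]
```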

19 changes: 12 additions & 7 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -26,6 +26,7 @@
     load_shard_state_dict,
     load_state_dict_into_model,
     load_states_into_optimizer,
+    save_config_file,
     save_param_groups,
     save_state_dict_shards,
     search_tp_partition_dim,
@@ -204,6 +205,7 @@ def save_sharded_model(self,
         if control_saving:
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
+            save_config_file(model, checkpoint)
             if self.verbose:
                 logging.info(f"The model is split into checkpoint shards. "
                              f"You can find where each parameters has been saved in the "
@@ -219,17 +221,18 @@ def save_sharded_model(self,
         Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

         # Manage filenames of sharded weights and index file for each pipeline stage.
-        weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
-        weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors")
-        save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
+        weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+        weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
+        save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
         save_index_file = os.path.join("tmp_index_files", save_index_file)

         total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
                                             checkpoint=checkpoint,
                                             index_file=index_file,
                                             base_filename=weights_name,
                                             is_master=control_saving,
-                                            use_safetensors=use_safetensors)
+                                            use_safetensors=use_safetensors,
+                                            use_pp_format=True)
         if control_saving:
             assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
             index_file.append_meta_data("total_size", total_size)
@@ -251,6 +254,7 @@ def save_sharded_model(self,
                 final_index_file.append_weight_map(weight, weight_filename)

             final_index_file.write_index_file(final_index_file_path)
+            save_config_file(model, checkpoint)
             rmtree(tmp_index_file_folder)
             if self.verbose:
                 logging.info(f"The model is split into checkpoint shards. "
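
The `pp_rank` → `pp_rank+1` change makes the stage index in shard filenames 1-based, matching the 1-based shard numbering of `get_shard_filename` later in this diff. A small sketch of the resulting names, assuming a hypothetical first pipeline stage:

```python
# Filename scheme after this change, for a hypothetical rank-0 stage.
pp_rank = 0
weights_name = "model.safetensors"
save_index_file = "model.safetensors.index.json"

weights_name = weights_name.replace(".safetensors", f"-stage-{pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{pp_rank+1:05d}.json")

print(weights_name)     # model-stage-00001-shard.safetensors
print(save_index_file)  # model.safetensors.index-stage-00001.json
```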
@@ -423,15 +427,16 @@ def save_sharded_optimizer(self,
         Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

         # Manage filenames of sharded weights and index file for each pipeline stage.
-        states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
-        save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
+        states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+        save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
         save_index_file = os.path.join("tmp_index_files", save_index_file)

         total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
                                             checkpoint=checkpoint,
                                             index_file=index_file,
                                             base_filename=states_name,
-                                            is_master=control_saving)
+                                            is_master=control_saving,
+                                            use_pp_format=True)

         if control_saving:
             assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
81 changes: 74 additions & 7 deletions colossalai/checkpoint_io/utils.py
@@ -9,12 +9,12 @@
 from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple

 import torch
-import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
+from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
+from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model

-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor.d_tensor import (
@@ -228,22 +228,25 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
                            index_file: "CheckpointIndexFile",
                            base_filename: str,
                            is_master: bool,
-                           use_safetensors: bool = False) -> int:
+                           use_safetensors: bool = False,
+                           use_pp_format: bool = False) -> int:
     '''
     Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
     Args:
         sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
         checkpoint (str): The path of checkpoint directory as string.
         index_file (CheckpointIndexFile): The index file object to be updated.
         base_filename (str): Decides the prefix of filenames of shards.
-        is_master (bool): Whether current rank is master.
-        use_safetensors (bool): Whether to use safetensors to save checkpoint.
+        is_master (bool): Whether current rank is main process.
+        use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
+        use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.

     Returns:
         int: the total size of shards
     '''

     total_size = 0
+    shard_filenames = []
     for idx, shard_pair in enumerate(sharded_state_dict):
         shard, current_size = shard_pair
         if not is_master:
@@ -257,8 +260,12 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]

         # Only save on master rank.
         save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+        shard_filenames.append(shard_file)
         del shard

+    # Clean folder, deleted unneeded files.
+    clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
+
     return total_size
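
`save_state_dict_shards` now records every filename it writes so that `clean_folder` (added below) can delete stale shards from earlier saves. A stripped-down sketch of that control flow, with the real file I/O stubbed out by hypothetical names:

```python
from typing import Iterator, Tuple

def save_shards_sketch(shards: Iterator[Tuple[dict, int]], is_master: bool) -> int:
    # Same shape as save_state_dict_shards: drain the generator, track
    # sizes and filenames, then prune anything not in shard_filenames.
    total_size = 0
    shard_filenames = []
    for idx, (shard, size) in enumerate(shards):
        if not is_master:
            continue    # non-master ranks still consume the generator
        filename = f"model-{idx+1:05d}.bin"
        # save_state_dict(shard, filename, ...) would run here
        shard_filenames.append(filename)
        total_size += size
    # clean_folder(checkpoint, base_filename, shard_filenames, ...) runs last,
    # removing leftover files that match the shard pattern but were not written.
    return total_size

print(save_shards_sketch(iter([({}, 10), ({}, 20)]), is_master=True))  # 30
```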


@@ -335,6 +342,66 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None:
     torch.save(param_groups, group_file_path)


+def clean_folder(checkpoint_path: str,
+                 weights_name: str,
+                 shard_filenames: List[str],
+                 is_master: bool = True,
+                 use_pp_format: bool = False):
+    """
+    Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
+
+    Args:
+        checkpoint_path (str): Path to the checkpoint directory.
+        weights_name (str): Decides the prefix of filenames of weight shards.
+        shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
+        is_master (bool, optional): Whether current rank is main process. Defaults to True.
+        use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
+
+    """
+    if is_master:
+        for filename in os.listdir(checkpoint_path):
+            full_filename = os.path.join(checkpoint_path, filename)
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+            if not use_pp_format:
+                reg = re.compile(r"(.*?)-\d{5}")
+            else:
+                # When this checkpoint is created by pipeline parallel process, the pattern is a little different.
+                reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
+            if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
+                    and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
+                os.remove(full_filename)
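
The two regexes decide which leftover files are deletion candidates: plain shards versus pipeline-stage shards. A quick check of what each pattern accepts, using hypothetical filenames with the `.bin`/`.safetensors` suffix already stripped as above:

```python
import re

reg_plain = re.compile(r"(.*?)-\d{5}")
reg_pp = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")

print(bool(reg_plain.fullmatch("pytorch_model-00002")))                 # True
print(bool(reg_pp.fullmatch("pytorch_model-stage-00001-shard-00003")))  # True
print(bool(reg_plain.fullmatch("pytorch_model")))                       # False: no shard index
print(bool(reg_pp.fullmatch("pytorch_model-stage-00001-shard")))        # False: no trailing shard index
```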


+def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
+    """
+    Save config.json/generation_config.json if model is a Huggingface pretrained model.
+    This method can only be called when a model is saved in a sharded way.
+
+    Args:
+        model (nn.Module): The model whose config should be saved if it's a huggingface model.
+        checkpoint_path (str): Path to the checkpoint directory.
+        is_master (bool): Whether current rank is main process.
+    """
+    if not isinstance(model, PreTrainedModel):
+        return
+
+    model = unwrap_huggingface_model(model)
+
+    # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
+    dtype = get_parameter_dtype(model)
+    model.config.torch_dtype = str(dtype).split(".")[1]
+
+    # Attach architecture to the config
+    model.config.architectures = [model.__class__.__name__]
+
+    # Save the config
+    if is_master:
+        model.config.save_pretrained(checkpoint_path)
+        if model.can_generate():
+            model.generation_config.save_pretrained(checkpoint_path)
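
Saving `config.json` (plus `generation_config.json` for generative models) is what lets a sharded checkpoint from this plugin be reloaded with plain `from_pretrained`, which the new test below exercises. A minimal sketch of the round-trip, assuming a hypothetical directory already populated by `save_sharded_model`:

```python
from transformers import AutoConfig, AutoModelForCausalLM

ckpt_dir = "path/to/saved/checkpoint"    # hypothetical directory

# Both calls resolve because save_config_file wrote config.json alongside
# the weight shards and the index file.
config = AutoConfig.from_pretrained(ckpt_dir)
print(config.torch_dtype, config.architectures)
model = AutoModelForCausalLM.from_pretrained(ckpt_dir)
```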


def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
"""
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
@@ -709,5 +776,5 @@ def get_shard_filename(weights_name: str, idx: int):
     get shard file name
     """
     shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
-    shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
+    shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
     return shard_file
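
The change above only normalizes whitespace inside the f-string; behavior is identical. For reference, a small sketch of the zero-based index to 1-based filename mapping, reusing the helper as defined here:

```python
def get_shard_filename(weights_name: str, idx: int):
    shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
    shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
    return shard_file

print(get_shard_filename("pytorch_model.bin", 0))    # pytorch_model-00001.bin
print(get_shard_filename("model.safetensors", 1))    # model-00002.safetensors
```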
129 changes: 129 additions & 0 deletions tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py
@@ -0,0 +1,129 @@
import pytest
import torch
import torch.distributed as dist
from torch.optim import Adam
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
    check_state_dict_equal,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo


def exam_from_pretrained(model_fn,
                         data_gen_fn,
                         output_transform_fn,
                         loss_fn,
                         test_config,
                         shard=True,
                         size_per_shard=32):

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    def _preprocess_data(data):
        if booster.plugin.stage_manager is not None:
            for k, v in data.items():
                if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
                    new_shape = [1] * v.dim()
                    new_shape[0] = 4
                    data[k] = v.to('cuda').repeat(*new_shape)
            return iter([data])
        else:
            return {k: v.cuda() for k, v in data.items()}

    model = model_fn()
    optimizer = Adam((model.parameters()), lr=0.001)
    criterion = loss_fn
    plugin = HybridParallelPlugin(**test_config)
    booster = Booster(plugin=plugin)

    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    data = data_gen_fn()
    model.train()
    if booster.plugin.stage_manager is not None:
        booster.execute_pipeline(_preprocess_data(data),
                                 model,
                                 _criterion,
                                 optimizer,
                                 return_loss=True,
                                 return_outputs=False)
    else:
        output = model(**_preprocess_data(data))
        loss = criterion(output)
        optimizer.backward(loss)

    optimizer.step()

    with shared_tempdir() as tempdir:

        model_ckpt_path = f"{tempdir}/model"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        dist.barrier()

        new_model = model.unwrap().__class__.from_pretrained(model_ckpt_path)
        new_optimizer = Adam(new_model.parameters(), lr=1e-3)
        new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

        check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)

    Randomizer.reset_index()
    torch.cuda.empty_cache()


@clear_cache_before_run()
@parameterize('test_config', [{
    'tp_size': 4,
    'pp_size': 1,
    'precision': 'fp32',
}, {
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 2,
    'pp_size': 1,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 4,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def run_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        exam_from_pretrained(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
    clear_layout_converter()
    torch.cuda.empty_cache()


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_test()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_huggingface_compatibility(world_size):
    spawn(run_dist, world_size)
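
For local runs outside CI, a sketch of a direct entry point in the style ColossalAI tests commonly use (assumes 4 visible CUDA devices; hypothetical, not part of the PR):

```python
# Hypothetical direct entry point; in CI the test runs under pytest's
# `dist` marker instead.
if __name__ == '__main__':
    test_huggingface_compatibility(world_size=4)
```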