Skip to content
Merged

f #99

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
585a2e4
Merge pull request #44 from hpcaitech/main
jamesthesnake May 14, 2023
86b899d
Merge pull request #40 from jamesthesnake/ra
jamesthesnake May 21, 2023
9fd12c6
Merge pull request #48 from jamesthesnake/ra
jamesthesnake May 27, 2023
75a485e
Merge pull request #51 from jamesthesnake/co
jamesthesnake Jun 1, 2023
b9d75dd
Merge pull request #52 from jamesthesnake/ra
jamesthesnake Jun 1, 2023
f8cf731
Merge pull request #59 from jamesthesnake/co
jamesthesnake Jun 5, 2023
948f6c9
Merge pull request #67 from jamesthesnake/co
jamesthesnake Jun 14, 2023
514d78b
Merge pull request #71 from jamesthesnake/co
jamesthesnake Jun 16, 2023
582e767
Merge pull request #75 from jamesthesnake/co
jamesthesnake Jun 19, 2023
030d5b1
Merge pull request #79 from jamesthesnake/l
jamesthesnake Jun 30, 2023
a0bb7fe
Merge pull request #84 from jamesthesnake/ra
jamesthesnake Jul 5, 2023
2d08156
Merge pull request #88 from jamesthesnake/l
jamesthesnake Jul 8, 2023
7421f84
Merge pull request #92 from jamesthesnake/ra
jamesthesnake Jul 10, 2023
91b0f5d
Merge pull request #96 from jamesthesnake/co
jamesthesnake Jul 14, 2023
9a4842c
revise shardformer readme (#4246)
CjhHa1 Jul 17, 2023
7ff11b5
[example] add llama pretraining (#4257)
binmakeswell Jul 17, 2023
4b97754
[Kernels] added triton-implemented of self attention for colossal-ai …
tiandiao123 Jul 18, 2023
fc5cef2
[lazy] support init on cuda (#4269)
ver217 Jul 19, 2023
c6f6005
[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)
Jul 21, 2023
02192a6
[ci] support testmon core pkg change detection (#4305)
ver217 Jul 21, 2023
c9f729e
Merge pull request #97 from hpcaitech/main
jamesthesnake Jul 21, 2023
6a5ca19
Merge pull request #98 from jamesthesnake/ra
jamesthesnake Jul 21, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build_on_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
TESTMON_CORE_PKGS: /__w/ColossalAI/ColossalAI/requirements/requirements.txt,/__w/ColossalAI/ColossalAI/requirements/requirements-test.txt

- name: Store Testmon Cache
run: |
Expand Down
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
</div>

## Latest News
* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana)
* [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs)
Expand All @@ -49,6 +50,7 @@
<li>
<a href="#Parallel-Training-Demo">Parallel Training Demo</a>
<ul>
<li><a href="#LLaMA">LLaMA</a></li>
<li><a href="#GPT-3">GPT-3</a></li>
<li><a href="#GPT-2">GPT-2</a></li>
<li><a href="#BERT">BERT</a></li>
Expand Down Expand Up @@ -216,6 +218,15 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)

## Parallel Training Demo

### LLaMA
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA_pretraining.png" width=600/>
</p>

- 65-billion-parameter large model pretraining accelerated by 38%
[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)

### GPT-3
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/GPT3-v5.png" width=700/>
Expand Down
131 changes: 107 additions & 24 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import gc
import logging
import os
import warnings
Expand All @@ -12,11 +13,19 @@
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
load_shard_state_dict,
save_state_dict,
save_state_dict_shards,
)
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini import ZeroOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats

from .dp_plugin_base import DPPluginBase
Expand All @@ -37,7 +46,7 @@ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor
"""
Save unsharded model to checkpoint but only on master process.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
As there is communication when getting state dict, this must be called on all processes.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
Expand All @@ -54,7 +63,7 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather
"""
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
As there is communication when getting state dict, this must be called on all processes.
As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
The saving process will only be executed by master rank.
"""
state_dict = optimizer.state_dict()
Expand All @@ -76,7 +85,8 @@ def save_sharded_model(self,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
Save sharded model
Save sharded model.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
Expand All @@ -86,36 +96,32 @@ def save_sharded_model(self,

state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
if not self.coordinator.is_master():
continue
shard = shard_pair[0]
shard_file = get_shard_filename(weights_name, idx)
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)

checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)

index_file.append_meta_data("total_size", total_size)
# Save shards of model states.
is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors)

# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")

def load_sharded_model(self,
                       model: GeminiDDP,
                       checkpoint_index_file: Path,
                       strict: bool = False,
                       use_safetensors: bool = False):
    """
    Load a sharded model, i.e. a model whose weights are spread over multiple files.
    The work is delegated to the parent implementation with load_sub_module set to False.
    """
    loaded = super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
    return loaded

Expand All @@ -125,16 +131,93 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_
Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes.
"""

# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.unwrap()

assert isinstance(optimizer, ZeroOptimizer)

if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return

Path(checkpoint).mkdir(parents=True, exist_ok=True)
super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)

# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)

# Store the information of param groups to param_group_file.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
param_groups = optimizer.get_param_groups_for_saving()
torch.save(param_groups, group_file_path)

# States are broken into shards within max_shard_size.
state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)

# Save shards of optimizer states.
is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=is_master,
use_safetensors=False)

# Wrap up index file. Only save it on master rank.
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")

def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
    """
    Load sharded optimizer states from a checkpoint folder, given its index file.
    For each process, only the optimizer states of the parameters it controls are loaded.

    Args:
        optimizer (Optimizer): The optimizer to restore. An OptimizerWrapper is unwrapped first;
            the underlying optimizer must be a ZeroOptimizer.
        checkpoint_index_file (Path): Path of the optimizer checkpoint index file.
        prefix (str): Not used by this method.

    Raises:
        FileNotFoundError: If checkpoint_index_file is not an existing file.
        RuntimeError: If the index file does not reference a param group file.
    """
    if not os.path.isfile(checkpoint_index_file):
        message = f"Provided path ({checkpoint_index_file}) should be a file"
        logging.error(message)
        # Stop here instead of going on to parse a path that is not an index file.
        raise FileNotFoundError(message)

    # If optimizer is wrapped, unwrap it.
    if isinstance(optimizer, OptimizerWrapper):
        optimizer = optimizer.unwrap()

    assert isinstance(optimizer, ZeroOptimizer)

    # Read checkpoint index file.
    ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)

    # Load param_groups.
    param_group_path = ckpt_index_file.get_param_group_filename()
    if param_group_path is None:
        # Implicit concatenation keeps source indentation out of the message.
        raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. '
                           'Lacking param group file under current directory.')
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    saved_param_groups = torch.load(param_group_path)
    optimizer.load_param_groups(saved_param_groups)

    checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

    # Load optimizer states from shard files under checkpoint path.
    # For each file, only load the states managed by current process.
    for shard_file in checkpoint_files:
        state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
        optimizer.load_param_states(state_dict_shard)
        # Drop the shard before reading the next one so shards do not pile up in memory.
        del state_dict_shard
        gc.collect()

    optimizer.optimizer_loading_epilogue()

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
    """
    Save the learning rate scheduler to checkpoint, but only on the master process.
    Every other process returns without doing anything.
    """
    if not self.coordinator.is_master():
        return
    super().save_lr_scheduler(lr_scheduler, checkpoint)


class GeminiModel(ModelWrapper):
Expand Down
38 changes: 18 additions & 20 deletions colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Iterator, Optional, OrderedDict, Tuple

import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer

Expand All @@ -16,7 +17,6 @@
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
has_index_file,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
Expand All @@ -25,6 +25,7 @@
load_states_into_optimizer,
save_param_groups,
save_state_dict,
save_state_dict_shards,
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
Expand Down Expand Up @@ -122,15 +123,13 @@ def save_sharded_optimizer(
save_param_groups(state_dict, group_file_path)

# Save shards of optimizer states.
total_size = 0
for idx, shard_pair in enumerate(sharded_state):
shard, current_size = shard_pair
shard_file = get_shard_filename(states_name, idx)
total_size = total_size + current_size
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False)

# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
Expand Down Expand Up @@ -172,18 +171,17 @@ def save_sharded_model(self,
# shard checkpoint
state_dict = model.state_dict()
state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)

weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
shard = shard_pair[0]
shard_file = get_shard_filename(weights_name, idx)
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)

# Save shards of model states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors)

index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
Expand Down
38 changes: 38 additions & 0 deletions colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# coding=utf-8
import os
import re
from collections import abc as container_abcs
from collections import defaultdict
Expand Down Expand Up @@ -103,6 +104,43 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
return unwrapped_optim


def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
                           checkpoint: str,
                           index_file: "CheckpointIndexFile",
                           base_filename: str,
                           is_master: bool,
                           use_safetensors: bool = False) -> int:
    """
    Write the shards of a state dict to disk and register them in the index file.
    Usable for both model and optimizer states. The shard iterator is consumed to the end
    on every rank, but files are written and the index is updated only when is_master is True
    (presumably because producing a shard may need communication on all ranks -- confirm with callers).

    Args:
        sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each item holds a state dict and its size.
        checkpoint (str): The path of checkpoint directory as string.
        index_file (CheckpointIndexFile): The index file object to be updated.
        base_filename (str): Decides the prefix of filenames of shards.
        is_master (bool): Whether current rank is master.
        use_safetensors (bool): Whether to use safetensors to save checkpoint.

    Returns:
        int: the summed size of the saved shards; 0 when is_master is False.
    """
    accumulated_size = 0
    for shard_idx, shard_pair in enumerate(sharded_state_dict):
        if is_master:
            state_shard, shard_size = shard_pair
            shard_name = get_shard_filename(base_filename, shard_idx)
            accumulated_size = accumulated_size + shard_size

            # Record which shard file holds each key.
            for entry_key in state_shard.keys():
                index_file.append_weight_map(entry_key, shard_name)

            # Reached on the master rank only, so only the master writes the file.
            save_state_dict(state_shard, os.path.join(checkpoint, shard_name), use_safetensors=use_safetensors)

    return accumulated_size


def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
Expand Down
Loading