29 changes: 21 additions & 8 deletions .github/workflows/build_on_pr.yml
@@ -40,8 +40,8 @@ jobs:
- name: Copy testmon cache
run: | # branch name may contain a slash, so we replace it with a space
export REF_BRANCH=$(echo ${{ github.event.ref }} | sed "s/\// /")
if [ -d /github/home/testmon_cache/${MAIN_BRANCH} ]; then
[ ! -z "$(ls -A /github/home/testmon_cache/${MAIN_BRANCH})" ] && cp -p -r /github/home/testmon_cache/${MAIN_BRANCH} "/github/home/testmon_cache/${REF_BRANCH}"
if [ -d /github/home/testmon_cache/${MAIN_BRANCH} ] && [ ! -z "$(ls -A /github/home/testmon_cache/${MAIN_BRANCH})" ]; then
cp -p -r /github/home/testmon_cache/${MAIN_BRANCH} "/github/home/testmon_cache/${REF_BRANCH}"
fi
env:
MAIN_BRANCH: ${{ github.event.master_branch }}
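The one-liner this replaces, `[ test ] && cmd`, leaves the step with a non-zero exit status whenever the test fails, so an empty or missing cache directory used to fail the job; the explicit `if` turns that case into a no-op. A minimal Python sketch of the same guard (paths and branch names are illustrative):

    # Copy a cache directory only if it exists and is non-empty,
    # mirroring `[ -d src ] && [ ! -z "$(ls -A src)" ]` above.
    import shutil
    from pathlib import Path

    def copy_cache(src: Path, dst: Path) -> None:
        if src.is_dir() and any(src.iterdir()):
            # shutil.copytree copies with copy2, preserving timestamps like `cp -p`.
            shutil.copytree(src, dst, dirs_exist_ok=True)

    copy_cache(Path("/github/home/testmon_cache/main"),
               Path("/github/home/testmon_cache/feature x"))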
@@ -60,12 +60,15 @@ jobs:
defaults:
run:
shell: bash
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- name: Copy testmon cache
run: | # branch name may contain a slash, so we replace it with a space
export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /")
if [ -d "/github/home/testmon_cache/${BASE}" ]; then
[ ! -z "$(ls -A "/github/home/testmon_cache/${BASE}")" ] && mkdir -p /github/home/testmon_cache/_pull && cp -p -r "/github/home/testmon_cache/${BASE}" /github/home/testmon_cache/_pull/${PR_NUMBER}
if [ -d "/github/home/testmon_cache/${BASE}" ] and [ ! -z "$(ls -A "/github/home/testmon_cache/${BASE}")" ]; then
mkdir -p /github/home/testmon_cache/_pull && cp -p -r "/github/home/testmon_cache/${BASE}" /github/home/testmon_cache/_pull/${PR_NUMBER}
fi
env:
PR_NUMBER: ${{ github.event.number }}
@@ -83,6 +86,9 @@ jobs:
changedLibraryFiles: ${{ steps.find-lib-change.outputs.all_changed_files }}
anyLibraryFileChanged: ${{ steps.find-lib-change.outputs.any_changed }}
runs-on: ubuntu-latest
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- uses: actions/checkout@v2
with:
@@ -140,6 +146,9 @@ jobs:
defaults:
run:
shell: bash
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- name: Checkout TensorNVMe
uses: actions/checkout@v2
@@ -150,7 +159,9 @@

- name: Restore TensorNVMe Cache
run: |
[ ! -z "$(ls -A /github/home/tensornvme_cache/)" ] && cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe
if [ -d /github/home/tensornvme_cache ] && [ ! -z "$(ls -A /github/home/tensornvme_cache/)" ]; then
cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe
fi

- name: Install TensorNVMe
run: |
@@ -173,7 +184,9 @@ jobs:
if: needs.detect.outputs.anyExtensionFileChanged != 'true'
run: |
# -p flag is required to preserve the file timestamp to avoid ninja rebuild
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
if [ -d /github/home/cuda_ext_cache ] && [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ]; then
cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
fi

- name: Install Colossal-AI
run: |
@@ -264,8 +277,8 @@ jobs:
if: github.event.pull_request.merged == true
run: | # branch name may contain a slash, so we replace it with a space
export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /")
if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ]; then
[ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ] && cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* "/github/home/testmon_cache/${BASE}/"
if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ] && [ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ]; then
cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* "/github/home/testmon_cache/${BASE}/"
fi
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
6 changes: 6 additions & 0 deletions .github/workflows/compatiblity_test_on_pr.yml
@@ -12,6 +12,9 @@ jobs:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- uses: actions/checkout@v3
- id: set-matrix
@@ -40,6 +43,9 @@ jobs:
image: ${{ matrix.container }}
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
timeout-minutes: 120
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- name: Install dependencies
run: |
6 changes: 6 additions & 0 deletions .github/workflows/doc_check_on_pr.yml
@@ -16,6 +16,9 @@ jobs:
github.event.pull_request.draft == false &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: ubuntu-latest
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- uses: actions/checkout@v2

@@ -31,6 +34,9 @@ jobs:
github.event.pull_request.draft == false &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: ubuntu-latest
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- uses: actions/checkout@v2
with:
6 changes: 6 additions & 0 deletions .github/workflows/doc_test_on_pr.yml
@@ -19,6 +19,9 @@ jobs:
outputs:
any_changed: ${{ steps.changed-files.outputs.any_changed }}
changed_files: ${{ steps.changed-files.outputs.all_changed_files }}
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
name: Detect changed example files
steps:
- uses: actions/checkout@v3
@@ -59,6 +62,9 @@ jobs:
defaults:
run:
shell: bash
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- name: Checkout ColossalAI-Documentation
uses: actions/checkout@v2
6 changes: 6 additions & 0 deletions .github/workflows/example_check_on_pr.yml
@@ -20,6 +20,9 @@ jobs:
matrix: ${{ steps.setup-matrix.outputs.matrix }}
anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }}
name: Detect changed example files
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- uses: actions/checkout@v3
with:
@@ -77,6 +80,9 @@ jobs:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 10
concurrency:
group: ${{ github.head_ref }}
cancel-in-progress: false
steps:
- uses: actions/checkout@v3

57 changes: 40 additions & 17 deletions colossalai/booster/booster.py
@@ -9,6 +9,7 @@
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper

from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -97,10 +98,10 @@ def __init__(self,
def boost(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
"""
Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
@@ -165,11 +166,11 @@ def no_sync(self, model: nn.Module) -> contextmanager:
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model)

def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
"""Load model from checkpoint.

Args:
model (nn.Module): A model boosted by Booster.
model (nn.Module or ModelWrapper): A model boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
strict (bool, optional): whether to strictly enforce that the keys
@@ -179,24 +180,34 @@ def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
self.checkpoint_io.load_model(model, checkpoint, strict)

def save_model(self,
model: nn.Module,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
prefix: str = None,
shard: bool = False,
size_per_shard: int = 1024):
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False):
"""Save model to checkpoint.

Args:
model (nn.Module): A model boosted by Booster.
model (nn.Module or ModelWrapper): A model boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It is a file path if ``shard=False``. Otherwise, it is a directory path.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
shard (bool, optional): Whether to save checkpoint a sharded way.
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved in safetensors format.
"""
self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)
self.checkpoint_io.save_model(model,
checkpoint=checkpoint,
shard=shard,
gather_dtensor=gather_dtensor,
prefix=prefix,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""Load optimizer from checkpoint.
@@ -205,22 +216,34 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
optimizer (Optimizer): An optimizer boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.load_optimizer(optimizer, checkpoint)

def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
"""Save optimizer to checkpoint.
Warning: Saving sharded optimizer checkpoint is not supported yet.
def save_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024):
"""
Save optimizer to checkpoint.

Args:
optimizer (Optimizer): An optimizer boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It is a file path if ``shard=False``. Otherwise, it is a directory path.
shard (bool, optional): Whether to save checkpoint a sharded way.
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""Save lr scheduler to checkpoint.
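With the changes above, `save_model` and `save_optimizer` expose the same sharding-related options (`shard`, `gather_dtensor`, `prefix`, `size_per_shard`, plus `use_safetensors` for models) and forward them to the plugin's checkpoint IO. A usage sketch, assuming a launched Colossal-AI process and a plugin whose checkpoint IO supports sharded checkpoints (paths and the prefix are illustrative):

    import torch.nn as nn
    from torch.optim import Adam

    import colossalai
    from colossalai.booster import Booster

    colossalai.launch_from_torch(config={})  # assumes the script was started with torchrun

    model = nn.Linear(16, 4)
    optimizer = Adam(model.parameters())
    booster = Booster()
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

    # Sharded checkpoint: `checkpoint` is a directory, shards capped at 512 MB,
    # saved via safetensors with a custom filename prefix.
    booster.save_model(model, checkpoint="./ckpt/model", shard=True,
                       gather_dtensor=True, prefix="demo",
                       size_per_shard=512, use_safetensors=True)
    booster.save_optimizer(optimizer, checkpoint="./ckpt/optim", shard=True,
                           gather_dtensor=True, prefix="demo", size_per_shard=512)

    # `load_model` now also accepts the wrapped model returned by `boost`.
    booster.load_model(model, "./ckpt/model")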
8 changes: 5 additions & 3 deletions colossalai/booster/mixed_precision/fp16_torch.py
@@ -115,10 +115,12 @@ def __init__(self,

def configure(self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
model = TorchAMPModule(model)
optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
if optimizer is not None:
optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
if criterion is not None:
criterion = TorchAMPModule(criterion)
return model, optimizer, criterion
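Wrapping the optimizer only when one is supplied is what lets `Booster.boost` be called for inference-only workloads. A sketch of that flow, assuming a launched Colossal-AI process with a CUDA device available:

    import torch
    import torch.nn as nn

    from colossalai.booster import Booster

    booster = Booster(mixed_precision='fp16')
    model = nn.Linear(16, 4).cuda()
    model, *_ = booster.boost(model)  # no optimizer/criterion/dataloader required now

    with torch.no_grad():
        out = model(torch.randn(2, 16, device='cuda'))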
7 changes: 4 additions & 3 deletions colossalai/booster/mixed_precision/mixed_precision_base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Tuple
from typing import Callable, Optional, Tuple

import torch.nn as nn
from torch.optim import Optimizer
@@ -15,7 +15,8 @@ class MixedPrecision(ABC):
@abstractmethod
def configure(self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
# TODO: implement this method
pass
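Custom mixed-precision implementations must now tolerate `optimizer=None`. A minimal hypothetical subclass, purely to illustrate the contract (it wraps nothing):

    from typing import Callable, Optional, Tuple

    import torch.nn as nn
    from torch.optim import Optimizer

    from colossalai.booster.mixed_precision import MixedPrecision

    class NoOpMixedPrecision(MixedPrecision):
        """Hypothetical no-op precision: passes everything through unchanged."""

        def configure(self,
                      model: nn.Module,
                      optimizer: Optional[Optimizer] = None,
                      criterion: Optional[Callable] = None,
                      ) -> Tuple[nn.Module, Optional[Optimizer], Optional[Callable]]:
            # A real implementation would wrap these; the key point is that
            # `optimizer` may be None and must be passed through untouched.
            return model, optimizer, criterion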
31 changes: 19 additions & 12 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -12,7 +12,7 @@
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
@@ -76,14 +76,14 @@ def save_sharded_model(self,
model: GeminiDDP,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
Save sharded model
"""
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
@@ -99,8 +99,11 @@ def save_sharded_model(self,
save_state_dict(shard, checkpoint_file_path, use_safetensors)

index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(f"The model is going to be split to checkpoint shards. "

# only save the index file on the master rank
if self.coordinator.is_master():
index_file.write_index_file(save_index_file)
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")

@@ -271,11 +274,11 @@ def supported_devices(self) -> List[str]:
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

if not isinstance(model, ModelWrapper):
# convert model to sync bn
@@ -290,8 +293,12 @@ def configure(
# wrap the model with Gemini
model = GeminiModel(model, self.gemini_config, self.verbose)

if not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(model.unwrap(),
optimizer,
self.zero_optim_config,
self.optim_kwargs,
self.verbose)

return model, optimizer, criterion, dataloader, lr_scheduler