52 changes: 52 additions & 0 deletions colossalai/booster/booster.py
@@ -151,6 +151,16 @@ def no_sync(self, model: nn.Module) -> contextmanager:
return self.plugin.no_sync(model)

def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
"""Load model from checkpoint.

Args:
model (nn.Module): A model boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
strict (bool, optional): Whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Defaults to True.
"""
self.checkpoint_io.load_model(model, checkpoint, strict)

def save_model(self,
@@ -159,16 +169,58 @@ def save_model(self,
prefix: str = None,
shard: bool = False,
size_per_shard: int = 1024):
"""Save model to checkpoint.

Args:
model (nn.Module): A model boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It is a file path if ``shard=False``. Otherwise, it is a directory path.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
shard (bool, optional): Whether to save the checkpoint in a sharded way.
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""Load optimizer from checkpoint.

Args:
optimizer (Optimizer): An optimizer boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
"""
self.checkpoint_io.load_optimizer(optimizer, checkpoint)

def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
"""Save optimizer to checkpoint.
Warning: Saving a sharded optimizer checkpoint is not supported yet.

Args:
optimizer (Optimizer): An optimizer boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It is a file path if ``shard=False``. Otherwise, it is a directory path.
shard (bool, optional): Whether to save the checkpoint in a sharded way.
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""Save lr scheduler to checkpoint.

Args:
lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local file path.
"""
self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)

def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""Load lr scheduler from checkpoint.

Args:
lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local file path.
"""
self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
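Taken together, these new methods form the checkpoint surface of `Booster`. Below is a minimal sketch of how a training script might call them; the paths, the helper functions, and the assumption that all objects were boosted beforehand are illustrative and not part of this diff.

```python
from colossalai.booster import Booster


def save_checkpoints(booster: Booster, model, optimizer, lr_scheduler):
    # Sharded model checkpoint: "./ckpt/model" is created as a directory of shard files.
    booster.save_model(model, "./ckpt/model", shard=True, size_per_shard=1024)
    # Sharded optimizer checkpoints are not supported yet, so keep shard=False (the default).
    booster.save_optimizer(optimizer, "./ckpt/optimizer.pt", shard=False)
    booster.save_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")


def load_checkpoints(booster: Booster, model, optimizer, lr_scheduler):
    # load_model detects whether the path is a sharded directory or a single file.
    booster.load_model(model, "./ckpt/model")
    booster.load_optimizer(optimizer, "./ckpt/optimizer.pt")
    booster.load_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")
```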
1 change: 1 addition & 0 deletions docs/sidebars.json
@@ -29,6 +29,7 @@
"basics/launch_colossalai",
"basics/booster_api",
"basics/booster_plugins",
"basics/booster_checkpoint",
"basics/define_your_config",
"basics/initialize_features",
"basics/engine_trainer",
48 changes: 48 additions & 0 deletions docs/source/en/basics/booster_checkpoint.md
@@ -0,0 +1,48 @@
# Booster Checkpoint

Author: [Hongxin Liu](https://github.com/ver217)

**Prerequisite:**
- [Booster API](./booster_api.md)

## Introduction

We've introduced the [Booster API](./booster_api.md) in the previous tutorial. In this tutorial, we will introduce how to save and load checkpoints using booster.

## Model Checkpoint

{{ autodoc:colossalai.booster.Booster.save_model }}

The model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to the saved checkpoint: a file path if `shard=False`, a directory otherwise. If `shard=True`, the checkpoint is saved in a sharded way, which is useful when it is too large to fit in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers).

{{ autodoc:colossalai.booster.Booster.load_model }}

The model must be boosted by `colossalai.booster.Booster` before loading. The checkpoint format is detected automatically and loaded in the corresponding way.
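As a minimal sketch of saving and reloading a boosted model (the toy model, the plugin choice, and the checkpoint path are placeholders; the `boost` call follows the [Booster API](./booster_api.md) tutorial):

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Assumes the script is started with torchrun so that distributed env vars are set.
colossalai.launch_from_torch(config={})
booster = Booster(plugin=TorchDDPPlugin())

# Toy model and optimizer as placeholders for a real training setup.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# Save a sharded checkpoint: "./ckpt/model" becomes a directory of shard files.
booster.save_model(model, "./ckpt/model", shard=True, size_per_shard=1024)
# Reload it: the format (sharded directory vs. single file) is detected automatically.
booster.load_model(model, "./ckpt/model")
```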

## Optimizer Checkpoint

> ⚠ Saving optimizer checkpoints in a sharded way is not supported yet.

{{ autodoc:colossalai.booster.Booster.save_optimizer }}

Optimizer must be boosted by `colossalai.booster.Booster` before saving.

{{ autodoc:colossalai.booster.Booster.load_optimizer }}

Optimizer must be boosted by `colossalai.booster.Booster` before loading.
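Continuing the sketch from the model section above (same boosted `optimizer`; the path is a placeholder):

```python
# Sharded optimizer checkpoints are not supported yet, so save a single file.
booster.save_optimizer(optimizer, "./ckpt/optimizer.pt")
# Load it back into the boosted optimizer.
booster.load_optimizer(optimizer, "./ckpt/optimizer.pt")
```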

## LR Scheduler Checkpoint

{{ autodoc:colossalai.booster.Booster.save_lr_scheduler }}

The LR scheduler must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the local path to the checkpoint file.

{{ autodoc:colossalai.booster.Booster.load_lr_scheduler }}

The LR scheduler must be boosted by `colossalai.booster.Booster` before loading. `checkpoint` is the local path to the checkpoint file.
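Again as a hedged sketch, assuming the scheduler was passed to `booster.boost(...)` together with the model and optimizer (the path is a placeholder):

```python
# lr_scheduler is assumed to have been boosted together with model and optimizer.
booster.save_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")
booster.load_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")
```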

## Checkpoint design

More details about checkpoint design can be found in our discussion [A Unified Checkpoint System Design](https://github.com/hpcaitech/ColossalAI/discussions/3339).

<!-- doc-test-command: echo -->
6 changes: 6 additions & 0 deletions docs/source/en/basics/booster_plugins.md
@@ -43,12 +43,16 @@ We've tested compatibility on some famous models, following models may not be su

Compatibility problems will be fixed in the future.

> ⚠ For now, this plugin can only load optimizer checkpoints that it saved itself, with the same number of processes. This will be fixed in the future.

### Gemini Plugin

This plugin implements Zero-3 with chunk-based and heterogeneous memory management. It can train large models without much loss in speed. It also does not support local gradient accumulation. More details can be found in [Gemini Doc](../features/zero_with_chunk.md).

{{ autodoc:colossalai.booster.plugin.GeminiPlugin }}

> ⚠ For now, this plugin can only load optimizer checkpoints that it saved itself, with the same number of processes. This will be fixed in the future.
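As a rough usage sketch of wiring the plugin into a booster (the constructor argument is an assumption for illustration; the autodoc above is authoritative):

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Assumes the script is started with torchrun so that distributed env vars are set.
colossalai.launch_from_torch(config={})
# The placement_policy argument is illustrative; check the GeminiPlugin autodoc above.
plugin = GeminiPlugin(placement_policy="cuda")
booster = Booster(plugin=plugin)
```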

### Torch DDP Plugin

More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).
@@ -62,3 +66,5 @@ More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/genera
More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.html).

{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }}

<!-- doc-test-command: echo -->
48 changes: 48 additions & 0 deletions docs/source/zh-Hans/basics/booster_checkpoint.md
@@ -0,0 +1,48 @@
# Booster Checkpoint

Author: [Hongxin Liu](https://github.com/ver217)

**Prerequisite:**
- [Booster API](./booster_api.md)

## Introduction

We introduced the [Booster API](./booster_api.md) in the previous tutorial. In this tutorial, we will show how to save and load checkpoints with the booster.

## Model Checkpoint

{{ autodoc:colossalai.booster.Booster.save_model }}

The model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to the saved checkpoint: a file path if `shard=False`, a directory otherwise. If `shard=True`, the checkpoint is saved in a sharded way, which is useful when it is too large to fit in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers).

{{ autodoc:colossalai.booster.Booster.load_model }}

The model must be boosted by `colossalai.booster.Booster` before loading. The checkpoint format is detected automatically and loaded in the corresponding way.

## Optimizer Checkpoint

> ⚠ Saving optimizer checkpoints in a sharded way is not supported yet.

{{ autodoc:colossalai.booster.Booster.save_optimizer }}

The optimizer must be boosted by `colossalai.booster.Booster` before saving.

{{ autodoc:colossalai.booster.Booster.load_optimizer }}

The optimizer must be boosted by `colossalai.booster.Booster` before loading.

## LR Scheduler Checkpoint

{{ autodoc:colossalai.booster.Booster.save_lr_scheduler }}

The LR scheduler must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the local path to the checkpoint file.

{{ autodoc:colossalai.booster.Booster.load_lr_scheduler }}

The LR scheduler must be boosted by `colossalai.booster.Booster` before loading. `checkpoint` is the local path to the checkpoint file.
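A consolidated, hedged sketch covering all three checkpoint types (the objects are assumed to have been boosted via `booster.boost(...)` already; paths are placeholders):

```python
# model, optimizer and lr_scheduler are assumed to have been boosted already;
# paths are placeholders.
booster.save_model(model, "./ckpt/model", shard=True)
booster.save_optimizer(optimizer, "./ckpt/optimizer.pt")
booster.save_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")

booster.load_model(model, "./ckpt/model")
booster.load_optimizer(optimizer, "./ckpt/optimizer.pt")
booster.load_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")
```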

## Checkpoint Design

More details about the checkpoint design can be found in our discussion [A Unified Checkpoint System Design](https://github.com/hpcaitech/ColossalAI/discussions/3339).

<!-- doc-test-command: echo -->
6 changes: 6 additions & 0 deletions docs/source/zh-Hans/basics/booster_plugins.md
@@ -43,12 +43,16 @@ Zero-2 does not support local gradient accumulation. If you insist on using it, although you can accumulate

兼容性问题将在未来修复。

> ⚠ For now, this plugin can only load optimizer checkpoints that it saved itself, with the same number of processes. This will be fixed in the future.

### Gemini Plugin

This plugin implements Zero-3 with chunk-based and heterogeneous memory management. It can train large models without much loss in speed. It does not support local gradient accumulation either. More details can be found in the [Gemini Doc](../features/zero_with_chunk.md).

{{ autodoc:colossalai.booster.plugin.GeminiPlugin }}

> ⚠ For now, this plugin can only load optimizer checkpoints that it saved itself, with the same number of processes. This will be fixed in the future.

### Torch DDP Plugin

More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).
@@ -62,3 +66,5 @@ Zero-2 does not support local gradient accumulation. If you insist on using it, although you can accumulate
More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.html).

{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }}

<!-- doc-test-command: echo -->