115 changes: 72 additions & 43 deletions colossalai/booster/booster.py
@@ -1,6 +1,6 @@
import warnings
from contextlib import contextmanager
from typing import Any, Callable, Iterator, List, Optional, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

import torch
import torch.nn as nn
@@ -24,29 +24,31 @@ class Booster:
Booster is a high-level API for training neural networks. It provides a unified interface for
training with different precisions, accelerators, and plugins.

Examples:
```python
colossalai.launch(...)
plugin = GeminiPlugin(...)
booster = Booster(precision='fp16', plugin=plugin)

model = GPT2()
optimizer = HybridAdam(model.parameters())
dataloader = Dataloader(Dataset)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()

model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)

for epoch in range(max_epochs):
for input_ids, attention_mask in dataloader:
outputs = model(input_ids, attention_mask)
loss = criterion(outputs.logits, input_ids)
booster.backward(loss, optimizer)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
```

```python
# The following is pseudocode

colossalai.launch(...)
plugin = GeminiPlugin(...)
booster = Booster(precision='fp16', plugin=plugin)

model = GPT2()
optimizer = HybridAdam(model.parameters())
dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()

model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)

for epoch in range(max_epochs):
for input_ids, attention_mask in dataloader:
outputs = model(input_ids.cuda(), attention_mask.cuda())
loss = criterion(outputs.logits, input_ids)
booster.backward(loss, optimizer)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
```

Args:
device (str or torch.device): The device to run the training. Default: None.
@@ -60,7 +62,7 @@ class Booster:

def __init__(self,
device: Optional[str] = None,
mixed_precision: Union[MixedPrecision, str] = None,
mixed_precision: Optional[Union[MixedPrecision, str]] = None,
plugin: Optional[Plugin] = None) -> None:
if plugin is not None:
assert isinstance(
@@ -110,14 +112,19 @@ def boost(
lr_scheduler: Optional[LRScheduler] = None,
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
"""
Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
Wrap and inject features into the passed-in model, optimizer, criterion, lr_scheduler, and dataloader.

Args:
model (nn.Module): The model to be boosted.
optimizer (Optimizer): The optimizer to be boosted.
criterion (Callable): The criterion to be boosted.
dataloader (DataLoader): The dataloader to be boosted.
lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
model (nn.Module): The model to be wrapped for distributed training.
The model might be decorated or partitioned by the plugin's strategy after execution of this method.
optimizer (Optimizer, optional): The optimizer to be wrapped for distributed training.
The optimizer's param groups or states might be decorated or partitioned by the plugin's strategy after execution of this method. Defaults to None.
criterion (Callable, optional): The function that calculates loss. Defaults to None.
dataloader (DataLoader, optional): The prepared dataloader for training. Defaults to None.
lr_scheduler (LRScheduler, optional): The learning rate scheduler for training. Defaults to None.

Returns:
List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments.
"""
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
# TODO(FrankLeeeee): consider multi-dataloader case
@@ -138,10 +145,10 @@ def boost(
return model, optimizer, criterion, dataloader, lr_scheduler

def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
"""Backward pass.
"""Execution of backward during training step.

Args:
loss (torch.Tensor): The loss to be backpropagated.
loss (torch.Tensor): The loss for backpropagation.
optimizer (Optimizer): The optimizer to be updated.
"""
# TODO(frank lee): implement this method with plugin
@@ -153,9 +160,31 @@ def execute_pipeline(self,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optional[Optimizer] = None,
return_loss: bool = True,
return_outputs: bool = False) -> dict:
# run pipeline forward backward pass
# return loss or outputs if needed
return_outputs: bool = False) -> Dict[str, Any]:
"""
Execute the forward and backward passes when using pipeline parallelism.
Return the loss or Huggingface-style model outputs if needed.

Warning: This function is tailored for pipeline parallel training.
Do not run the forward/backward pass in the conventional way (model(input) / loss.backward())
when doing pipeline parallel training with the booster; doing so will cause unexpected errors.

Args:
data_iter (Iterator): The iterator that yields the next batch of data. There are usually two ways to obtain this argument:
1. wrap the dataloader in an iterator: iter(dataloader)
2. fetch the next batch from the dataloader and wrap that batch in an iterator: iter([batch])
model (nn.Module): The model to execute forward/backward; it should be a model wrapped by a plugin that supports pipelines.
criterion (Callable[[Any, Any], torch.Tensor]): The criterion to be used. It should take two arguments, the model outputs and the inputs, and return a loss tensor.
'lambda y, x: loss_fn(y)' can turn a normal loss function into a valid two-argument criterion here.
optimizer (Optimizer, optional): The optimizer for execution of backward. Can be None when only doing forward (i.e. evaluation). Defaults to None.
return_loss (bool, optional): Whether to return loss in the dict returned by this method. Defaults to True.
return_outputs (bool, optional): Whether to return Huggingface-style model outputs in the dict returned by this method. Defaults to False.

Returns:
Dict[str, Any]: Output dict in the form {'loss': ..., 'outputs': ...}.
ret_dict['loss'] is the loss of the forward pass if return_loss is set to True, else None.
ret_dict['outputs'] is the Huggingface-style model outputs from the forward pass if return_outputs is set to True, else None.
"""
assert isinstance(self.plugin,
PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
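As an illustration, here is a minimal sketch of a training step built on `execute_pipeline`. It is not the library's canonical loop: the names `booster`, `model`, `optimizer`, `dataloader`, and `loss_fn` are placeholders, and the model is assumed to have been boosted with a pipeline-capable plugin:

```python
# Hedged sketch: `dataloader`, `loss_fn`, `model`, and `optimizer` are
# placeholders; `model`/`optimizer` are assumed to have been boosted with a
# pipeline-capable plugin (e.g. HybridParallelPlugin).
for batch in dataloader:
    data_iter = iter([batch])  # way 2 from the docstring: wrap one batch in an iterator
    ret = booster.execute_pipeline(
        data_iter,
        model,
        # turn a one-argument loss function into the required two-argument criterion
        criterion=lambda outputs, inputs: loss_fn(outputs),
        optimizer=optimizer,
        return_loss=True,
    )
    loss = ret['loss']  # the forward loss, since return_loss=True
    optimizer.step()
    optimizer.zero_grad()
```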
@@ -175,7 +204,7 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -
assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model, optimizer)
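For illustration, here is a sketch of gradient accumulation using `no_sync`, assuming the active plugin supports it (e.g. TorchDDPPlugin) and that `no_sync` returns a context manager, as PyTorch DDP's does; `micro_batches` and `criterion` are placeholder names:

```python
# Hedged sketch of gradient accumulation: skip gradient synchronization on all
# but the last micro-batch. `micro_batches`, `criterion`, `model`, `optimizer`,
# and `booster` are assumed to exist; the plugin must support no_sync.
for micro_batch in micro_batches[:-1]:
    with booster.no_sync(model):             # gradients accumulate locally, no all-reduce
        loss = criterion(model(micro_batch))
        booster.backward(loss, optimizer)
loss = criterion(model(micro_batches[-1]))   # last micro-batch: gradients are synced
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
```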

def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
"""Load model from checkpoint.

Args:
@@ -195,15 +224,15 @@ def save_model(self,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False):
use_safetensors: bool = False) -> None:
"""Save model to checkpoint.

Args:
model (nn.Module or ModelWrapper): A model boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It is a file path if ``shard=False``. Otherwise, it is a directory path.
shard (bool, optional): Whether to save the checkpoint in a sharded way.
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
If true, the checkpoint will be a folder in the same format as a Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
gather_dtensor (bool, optional): Whether to gather the distributed tensor to the first device. Default: True.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
@@ -218,7 +247,7 @@ def save_model(self,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
"""Load optimizer from checkpoint.

Args:
@@ -237,7 +266,7 @@ def save_optimizer(self,
shard: bool = False,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024):
size_per_shard: int = 1024) -> None:
"""
Save optimizer to checkpoint.

@@ -254,7 +283,7 @@ def save_optimizer(self,
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Save lr scheduler to checkpoint.

Args:
@@ -263,7 +292,7 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)

def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Load lr scheduler from checkpoint.

Args:
27 changes: 19 additions & 8 deletions docs/source/en/basics/booster_api.md
@@ -1,6 +1,6 @@
# Booster API

Author: [Mingyan Jiang](https://github.com/jiangmingyan) [Jianghai Chen](https://github.com/CjhHa1)
Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Jianghai Chen](https://github.com/CjhHa1), [Baizhou Zhang](https://github.com/Fridge003)

**Prerequisite:**

@@ -9,32 +9,35 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan) [Jianghai Chen](https:/

**Example Code**

- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet/README.md)
- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet)

## Introduction

In our new design, `colossalai.booster` replaces the role of `colossalai.initialize` to inject features into your training components (e.g. model, optimizer, dataloader) seamlessly. With these new APIs, you can integrate your model with our parallelism features more friendly. Also calling `colossalai.booster` is the standard procedure before you run into your training loops. In the sections below, I will cover how `colossalai.booster` works and what we should take note of.
In our new design, `colossalai.booster` replaces the role of `colossalai.initialize` to inject features into your training components (e.g. model, optimizer, dataloader) seamlessly. With these new APIs, you can integrate your model with our parallelism features more easily. Also, calling `colossalai.booster` is the standard procedure before you run your training loop. In the sections below, we will cover how `colossalai.booster` works and what we should take note of.

### Plugin

Plugin is an important component that manages the parallel configuration (e.g. the Gemini plugin encapsulates the Gemini acceleration solution). Currently supported plugins are as follows:

**_HybridParallelPlugin:_** This plugin wraps the hybrid parallel training acceleration solution. It provides an interface for any combination of tensor parallel, pipeline parallel, and data parallel strategies, including DDP and ZeRO.

**_GeminiPlugin:_** This plugin wraps the Gemini acceleration solution, that is, ZeRO with chunk-based memory management.

**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of Pytorch. It implements data parallelism at the module level which can run across multiple machines.
**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of PyTorch. It implements data parallelism at the module level and can run across multiple machines.

**_LowLevelZeroPlugin:_** This plugin wraps stage 1/2 of the Zero Redundancy Optimizer. Stage 1: shards optimizer states across data parallel workers/GPUs. Stage 2: shards optimizer states + gradients across data parallel workers/GPUs.


**_TorchFSDPPlugin:_** This plugin wraps the FSDP acceleration solution of PyTorch and can be used to train models with ZeRO-DP.

More details about the usage of each plugin can be found in the chapter [Booster Plugins](./booster_plugins.md). A minimal illustration of swapping plugins is sketched below.
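As a minimal sketch, switching acceleration strategies only requires swapping the plugin passed to `Booster`; the constructor arguments shown in the commented line are illustrative assumptions, so check each plugin's documentation for the real options:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin, LowLevelZeroPlugin

# Pick one strategy; the rest of the training code stays unchanged.
plugin = TorchDDPPlugin()                 # PyTorch DDP data parallelism
# plugin = LowLevelZeroPlugin(stage=2)    # ZeRO stage 2 (argument name assumed)
booster = Booster(plugin=plugin)
```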

### API of booster

{{ autodoc:colossalai.booster.Booster }}

## Usage

In a typical workflow, you should launch distributed environment at the beginning of training script and create objects needed (such as models, optimizers, loss function, data loaders etc.) firstly, then call `colossalai.booster` to inject features into these objects, After that, you can use our booster APIs and these returned objects to continue the rest of your training processes.
In a typical workflow, you should launch the distributed environment at the beginning of the training script and first create the objects needed (such as models, optimizers, loss functions, data loaders, etc.), then call `booster.boost` to inject features into these objects. After that, you can use our booster APIs and the returned objects to continue the rest of your training process.

A pseudocode example is shown below:

@@ -48,15 +51,21 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

def train():
# launch colossalai
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

# create plugin and objects for training
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
criterion = lambda x: x.mean()
optimizer = SGD((model.parameters()), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

# use booster.boost to wrap the training objects
model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

# do training as normal, except that the backward should be called by booster
x = torch.randn(4, 3, 224, 224)
x = x.to('cuda')
output = model(x)
@@ -65,14 +74,16 @@ def train():
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()

# checkpointing using booster api
save_path = "./model"
booster.save_model(model, save_path, True, True, "", 10, use_safetensors=use_safetensors)
booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)

new_model = resnet18()
booster.load_model(new_model, save_path)
```

[more design details](https://github.com/hpcaitech/ColossalAI/discussions/3046)
For more design details, please see [this page](https://github.com/hpcaitech/ColossalAI/discussions/3046).

<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 booster_api.py -->
2 changes: 1 addition & 1 deletion docs/source/en/basics/booster_checkpoint.md
@@ -13,7 +13,7 @@ We've introduced the [Booster API](./booster_api.md) in the previous tutorial. I

{{ autodoc:colossalai.booster.Booster.save_model }}

Model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to saved checkpoint. It can be a file, if `shard=False`. Otherwise, it should be a directory. If `shard=True`, the checkpoint will be saved in a sharded way. This is useful when the checkpoint is too large to be saved in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers).
Model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to the saved checkpoint. It can be a file if `shard=False`; otherwise, it should be a directory. If `shard=True`, the checkpoint will be saved in a sharded way, which is useful when the checkpoint is too large to be saved in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers), so you can use Huggingface's `from_pretrained` method to load a model from our sharded checkpoint, as sketched below.
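A hedged sketch of that round trip, assuming `model` is a boosted Huggingface transformers model and that the checkpoint directory ends up containing everything `from_pretrained` expects (such as the model config):

```python
# Save a sharded checkpoint with Booster, then reload it through transformers.
# `model` and `booster` are assumed to exist; paths are placeholders.
booster.save_model(model, "./ckpt_dir", shard=True, size_per_shard=1024)

from transformers import AutoModelForCausalLM
new_model = AutoModelForCausalLM.from_pretrained("./ckpt_dir")
```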

{{ autodoc:colossalai.booster.Booster.load_model }}
