Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
56 changes: 56 additions & 0 deletions examples/tutorial/new_api/cifar_resnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Train ResNet on CIFAR-10 from scratch

## 🚀 Quick Start

This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch.

- Training Arguments
- `-p`, `--plugin`: Plugin to use. Choices: `torch_ddp`, `torch_ddp_fp16`, `low_level_zero`. Defaults to `torch_ddp`.
- `-r`, `--resume`: Resume from checkpoint file path. Defaults to `-1`, which means not resuming.
- `-c`, `--checkpoint`: The folder to save checkpoints. Defaults to `./checkpoint`.
- `-i`, `--interval`: Epoch interval to save checkpoints. Defaults to `5`. If set to `0`, no checkpoint will be saved.
- `--target_acc`: Target accuracy. Raise exception if not reached. Defaults to `None`.

- Eval Arguments
- `-e`, `--epoch`: select the epoch to evaluate
- `-c`, `--checkpoint`: the folder where checkpoints are found

### Install requirements

```bash
pip install -r requirements.txt
```

### Train

```bash
# train with torch DDP with fp32
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32

# train with torch DDP with mixed precision training
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 -p torch_ddp_fp16

# train with low level zero
colossalai run --nproc_per_node 2 train.py -c ./ckpt-low_level_zero -p low_level_zero
```

### Eval

```bash
# evaluate fp32 training
python eval.py -c ./ckpt-fp32 -e 80

# evaluate fp16 mixed precision training
python eval.py -c ./ckpt-fp16 -e 80

# evaluate low level zero training
python eval.py -c ./ckpt-low_level_zero -e 80
```

Expected accuracy performance will be:

| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero |
| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- |
| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% |

**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**
4 changes: 4 additions & 0 deletions examples/tutorial/new_api/cifar_resnet/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
colossalai
torch
torchvision
tqdm
10 changes: 10 additions & 0 deletions examples/tutorial/new_api/cifar_resnet/test_ci.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash
set -xe

export DATA=/data/scratch/cifar-10

pip install -r requirements.txt

for plugin in "torch_ddp" "torch_ddp_fp16" "low_level_zero"; do
colossalai run --nproc_per_node 4 train.py --interval 0 --target_acc 0.84 --plugin $plugin
done
210 changes: 210 additions & 0 deletions examples/tutorial/new_api/cifar_resnet/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import argparse
import os
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim import Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from tqdm import tqdm

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 80
LEARNING_RATE = 1e-3


def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase):
# trainsform
transform_train = transforms.Compose(
[transforms.Pad(4),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32),
transforms.ToTensor()])
transform_test = transforms.ToTensor()

# CIFAR-10 dataset
data_path = os.environ.get('DATA', './data')
with coordinator.priority_execution():
train_dataset = torchvision.datasets.CIFAR10(root=data_path,
train=True,
transform=transform_train,
download=True)
test_dataset = torchvision.datasets.CIFAR10(root=data_path,
train=False,
transform=transform_test,
download=True)

# Data loader
train_dataloader = plugin.prepare_train_dataloader(train_dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True)
test_dataloader = plugin.prepare_train_dataloader(test_dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False)
return train_dataloader, test_dataloader


@torch.no_grad()
def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
model.eval()
correct = torch.zeros(1, dtype=torch.int64, device=get_current_device())
total = torch.zeros(1, dtype=torch.int64, device=get_current_device())
for images, labels in test_dataloader:
images = images.cuda()
labels = labels.cuda()
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
dist.all_reduce(correct)
dist.all_reduce(total)
accuracy = correct.item() / total.item()
if coordinator.is_master():
print(f'Accuracy of the model on the test images: {accuracy * 100:.2f} %')
return accuracy


def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: nn.Module, train_dataloader: DataLoader,
booster: Booster, coordinator: DistCoordinator):
model.train()
with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
for images, labels in pbar:
images = images.cuda()
labels = labels.cuda()
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)

# Backward and optimize
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()

# Print log info
pbar.set_postfix({'loss': loss.item()})


def main():
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
# FIXME(ver217): gemini is not supported resnet now
parser.add_argument('-p',
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'],
help="plugin to use")
parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint")
parser.add_argument('--target_acc',
type=float,
default=None,
help="target accuracy. Raise exception if not reached")
args = parser.parse_args()

# ==============================
# Prepare Checkpoint Directory
# ==============================
if args.interval > 0:
Path(args.checkpoint).mkdir(parents=True, exist_ok=True)

# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={})
coordinator = DistCoordinator()

# update the learning rate with linear scaling
# old_gpu_num / old_lr = new_gpu_num / new_lr
global LEARNING_RATE
LEARNING_RATE *= coordinator.world_size

# ==============================
# Instantiate Plugin and Booster
# ==============================
booster_kwargs = {}
if args.plugin == 'torch_ddp_fp16':
booster_kwargs['mixed_precision'] = 'fp16'
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)

booster = Booster(plugin=plugin, **booster_kwargs)

# ==============================
# Prepare Dataloader
# ==============================
train_dataloader, test_dataloader = build_dataloader(100, coordinator, plugin)

# ====================================
# Prepare model, optimizer, criterion
# ====================================
# resent50
model = torchvision.models.resnet18(num_classes=10)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE)

# lr scheduler
lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3)

# ==============================
# Boost with ColossalAI
# ==============================
model, optimizer, criterion, _, lr_scheduler = booster.boost(model,
optimizer,
criterion=criterion,
lr_scheduler=lr_scheduler)

# ==============================
# Resume from checkpoint
# ==============================
if args.resume >= 0:
booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth')
booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth')
booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth')

# ==============================
# Train model
# ==============================
start_epoch = args.resume if args.resume >= 0 else 0
for epoch in range(start_epoch, NUM_EPOCHS):
train_epoch(epoch, model, optimizer, criterion, train_dataloader, booster, coordinator)
lr_scheduler.step()

# save checkpoint
if args.interval > 0 and (epoch + 1) % args.interval == 0:
booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth')
booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth')
booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth')

accuracy = evaluate(model, test_dataloader, coordinator)
if args.target_acc is not None:
assert accuracy >= args.target_acc, f'Accuracy {accuracy} is lower than target accuracy {args.target_acc}'


if __name__ == '__main__':
main()
37 changes: 37 additions & 0 deletions examples/tutorial/new_api/cifar_vit/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Train ViT on CIFAR-10 from scratch

## 🚀 Quick Start

This example provides a training script, which provides an example of training ViT on CIFAR10 dataset from scratch.

- Training Arguments
- `-p`, `--plugin`: Plugin to use. Choices: `torch_ddp`, `torch_ddp_fp16`, `low_level_zero`. Defaults to `torch_ddp`.
- `-r`, `--resume`: Resume from checkpoint file path. Defaults to `-1`, which means not resuming.
- `-c`, `--checkpoint`: The folder to save checkpoints. Defaults to `./checkpoint`.
- `-i`, `--interval`: Epoch interval to save checkpoints. Defaults to `5`. If set to `0`, no checkpoint will be saved.
- `--target_acc`: Target accuracy. Raise exception if not reached. Defaults to `None`.

### Install requirements

```bash
pip install -r requirements.txt
```

### Train

```bash
# train with torch DDP with fp32
colossalai run --nproc_per_node 4 train.py -c ./ckpt-fp32

# train with torch DDP with mixed precision training
colossalai run --nproc_per_node 4 train.py -c ./ckpt-fp16 -p torch_ddp_fp16

# train with low level zero
colossalai run --nproc_per_node 4 train.py -c ./ckpt-low_level_zero -p low_level_zero
```

Expected accuracy performance will be:

| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero |
| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- |
| ViT | 83.00% | 84.03% | 84.00% | 84.43% |
5 changes: 5 additions & 0 deletions examples/tutorial/new_api/cifar_vit/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
colossalai
timm
torch
torchvision
tqdm
10 changes: 10 additions & 0 deletions examples/tutorial/new_api/cifar_vit/test_ci.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash
set -xe

export DATA=/data/scratch/cifar-10

pip install -r requirements.txt

for plugin in "torch_ddp" "torch_ddp_fp16" "low_level_zero"; do
colossalai run --nproc_per_node 4 train.py --interval 0 --target_acc 0.83 --plugin $plugin
done
Loading