Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/images/resnet/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
data
checkpoint
ckpt-fp16
ckpt-fp32
56 changes: 56 additions & 0 deletions examples/images/resnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Train ResNet on CIFAR-10 from scratch

## 🚀 Quick Start

This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch.

- Training Arguments
- `-p`, `--plugin`: Plugin to use. Choices: `torch_ddp`, `torch_ddp_fp16`, `low_level_zero`. Defaults to `torch_ddp`.
- `-r`, `--resume`: Resume from checkpoint file path. Defaults to `-1`, which means not resuming.
- `-c`, `--checkpoint`: The folder to save checkpoints. Defaults to `./checkpoint`.
- `-i`, `--interval`: Epoch interval to save checkpoints. Defaults to `5`. If set to `0`, no checkpoint will be saved.
- `--target_acc`: Target accuracy. Raise exception if not reached. Defaults to `None`.

- Eval Arguments
- `-e`, `--epoch`: select the epoch to evaluate
- `-c`, `--checkpoint`: the folder where checkpoints are found

### Install requirements

```bash
pip install -r requirements.txt
```

### Train
The folders will be created automatically.
```bash
# train with torch DDP with fp32
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32

# train with torch DDP with mixed precision training
colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 -p torch_ddp_fp16

# train with low level zero
colossalai run --nproc_per_node 2 train.py -c ./ckpt-low_level_zero -p low_level_zero
```

### Eval

```bash
# evaluate fp32 training
python eval.py -c ./ckpt-fp32 -e 80

# evaluate fp16 mixed precision training
python eval.py -c ./ckpt-fp16 -e 80

# evaluate low level zero training
python eval.py -c ./ckpt-low_level_zero -e 80
```

Expected accuracy performance will be:

| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero |
| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- |
| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% |

**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**
48 changes: 48 additions & 0 deletions examples/images/resnet/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import argparse

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epoch', type=int, default=80, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
args = parser.parse_args()

# ==============================
# Prepare Test Dataset
# ==============================
# CIFAR-10 dataset
test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor())

# Data loader
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

# ==============================
# Load Model
# ==============================
model = torchvision.models.resnet18(num_classes=10).cuda()
state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth')
model.load_state_dict(state_dict)

# ==============================
# Run Evaluation
# ==============================
model.eval()

with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.cuda()
labels = labels.cuda()
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))
5 changes: 5 additions & 0 deletions examples/images/resnet/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
colossalai
torch
torchvision
tqdm
pytest
12 changes: 12 additions & 0 deletions examples/images/resnet/test_ci.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash
set -xe

export DATA=/data/scratch/cifar-10

pip install -r requirements.txt

# TODO: skip ci test due to time limits, train.py needs to be rewritten.

# for plugin in "torch_ddp" "torch_ddp_fp16" "low_level_zero"; do
# colossalai run --nproc_per_node 4 train.py --interval 0 --target_acc 0.84 --plugin $plugin
# done
204 changes: 204 additions & 0 deletions examples/images/resnet/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import argparse
import os
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim import Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from tqdm import tqdm

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 80
LEARNING_RATE = 1e-3


def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase):
# transform
transform_train = transforms.Compose(
[transforms.Pad(4),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32),
transforms.ToTensor()])
transform_test = transforms.ToTensor()

# CIFAR-10 dataset
data_path = os.environ.get('DATA', './data')
with coordinator.priority_execution():
train_dataset = torchvision.datasets.CIFAR10(root=data_path,
train=True,
transform=transform_train,
download=True)
test_dataset = torchvision.datasets.CIFAR10(root=data_path,
train=False,
transform=transform_test,
download=True)

# Data loader
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
return train_dataloader, test_dataloader


@torch.no_grad()
def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
model.eval()
correct = torch.zeros(1, dtype=torch.int64, device=get_current_device())
total = torch.zeros(1, dtype=torch.int64, device=get_current_device())
for images, labels in test_dataloader:
images = images.cuda()
labels = labels.cuda()
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
dist.all_reduce(correct)
dist.all_reduce(total)
accuracy = correct.item() / total.item()
if coordinator.is_master():
print(f'Accuracy of the model on the test images: {accuracy * 100:.2f} %')
return accuracy


def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: nn.Module, train_dataloader: DataLoader,
booster: Booster, coordinator: DistCoordinator):
model.train()
with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
for images, labels in pbar:
images = images.cuda()
labels = labels.cuda()
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)

# Backward and optimize
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()

# Print log info
pbar.set_postfix({'loss': loss.item()})


def main():
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
# FIXME(ver217): gemini is not supported resnet now
parser.add_argument('-p',
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'],
help="plugin to use")
parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint")
parser.add_argument('--target_acc',
type=float,
default=None,
help="target accuracy. Raise exception if not reached")
args = parser.parse_args()

# ==============================
# Prepare Checkpoint Directory
# ==============================
if args.interval > 0:
Path(args.checkpoint).mkdir(parents=True, exist_ok=True)

# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={})
coordinator = DistCoordinator()

# update the learning rate with linear scaling
# old_gpu_num / old_lr = new_gpu_num / new_lr
global LEARNING_RATE
LEARNING_RATE *= coordinator.world_size

# ==============================
# Instantiate Plugin and Booster
# ==============================
booster_kwargs = {}
if args.plugin == 'torch_ddp_fp16':
booster_kwargs['mixed_precision'] = 'fp16'
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)

booster = Booster(plugin=plugin, **booster_kwargs)

# ==============================
# Prepare Dataloader
# ==============================
train_dataloader, test_dataloader = build_dataloader(100, coordinator, plugin)

# ====================================
# Prepare model, optimizer, criterion
# ====================================
# resent50
model = torchvision.models.resnet18(num_classes=10)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE)

# lr scheduler
lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3)

# ==============================
# Boost with ColossalAI
# ==============================
model, optimizer, criterion, _, lr_scheduler = booster.boost(model,
optimizer,
criterion=criterion,
lr_scheduler=lr_scheduler)

# ==============================
# Resume from checkpoint
# ==============================
if args.resume >= 0:
booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth')
booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth')
booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth')

# ==============================
# Train model
# ==============================
start_epoch = args.resume if args.resume >= 0 else 0
for epoch in range(start_epoch, NUM_EPOCHS):
train_epoch(epoch, model, optimizer, criterion, train_dataloader, booster, coordinator)
lr_scheduler.step()

# save checkpoint
if args.interval > 0 and (epoch + 1) % args.interval == 0:
booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth')
booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth')
booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth')

accuracy = evaluate(model, test_dataloader, coordinator)
if args.target_acc is not None:
assert accuracy >= args.target_acc, f'Accuracy {accuracy} is lower than target accuracy {args.target_acc}'


if __name__ == '__main__':
main()