38 changes: 21 additions & 17 deletions examples/tutorial/large_batch_optimizer/README.md
@@ -1,31 +1,35 @@
# Comparison of Large Batch Training Optimization

## 🚀 Quick Start
Run with synthetic data
```bash
colossalai run --nproc_per_node 4 train.py --config config.py -s
```
## Table of contents

- [Overview](#-overview)
- [Quick Start](#-quick-start)

## Prepare Dataset
## 📚 Overview

We use the CIFAR10 dataset in this example. You should invoke `donwload_cifar10.py` in the tutorial root directory or directly run `auto_parallel_with_resnet.py`.
The dataset will be downloaded to `colossalai/examples/tutorials/data` by default.
If you wish to use a customized directory for the dataset, you can set the environment variable `DATA` via the following command.
This example lets you quickly try out the large batch training optimization provided by Colossal-AI. We use a synthetic dataset to go through the process, so you don't need to prepare any dataset. You can try out the `Lamb` and `Lars` optimizers from Colossal-AI with the following code.

```bash
export DATA=/path/to/data
```python
from colossalai.nn.optimizer import Lamb, Lars
```
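As a minimal usage sketch, either optimizer is constructed the same way as in `train.py` below; the `lr` and `weight_decay` values mirror `LEARNING_RATE` and `WEIGHT_DECAY` in `config.py`, and the toy `nn.Linear` model is only a stand-in for the ResNet-18 used there.

```python
import torch.nn as nn

from colossalai.nn.optimizer import Lamb, Lars

# toy model standing in for the ResNet-18 built in train.py
model = nn.Linear(32, 10)

# both optimizers take the constructor arguments used in train.py
optimizer = Lars(model.parameters(), lr=3e-3, weight_decay=0.3)
# optimizer = Lamb(model.parameters(), lr=3e-3, weight_decay=0.3)
```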

If you don't wish to download the `CIFAR10` dataset, you can use synthetic data for this tutorial by adding the `-s` or `--synthetic` flag to the command.
## 🚀 Quick Start

1. Install PyTorch

2. Install the dependencies.

```bash
pip install -r requirements.txt
```

## Run on 2*2 device mesh
3. Run the training scripts with synthetic data.

```bash
# run with cifar10
colossalai run --nproc_per_node 4 train.py --config config.py
# run on 4 GPUs
# run with lars
colossalai run --nproc_per_node 4 train.py --config config.py --optimizer lars

# run with synthetic dataset
colossalai run --nproc_per_node 4 train.py --config config.py -s
# run with lamb
colossalai run --nproc_per_node 4 train.py --config config.py --optimizer lamb
```
26 changes: 3 additions & 23 deletions examples/tutorial/large_batch_optimizer/config.py
@@ -6,31 +6,11 @@
BATCH_SIZE = 512
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 10
WARMUP_EPOCHS = 3
NUM_EPOCHS = 2
WARMUP_EPOCHS = 1

# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 512
DEPTH = 4
NUM_HEADS = 4
MLP_RATIO = 2
NUM_CLASSES = 1000
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token

# parallel setting
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'

parallel = dict(
pipeline=2,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
NUM_CLASSES = 10

fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0

# pipeline config
NUM_MICRO_BATCHES = parallel['pipeline']
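For reference, the fields kept in `config.py` are read inside `train.py` through Colossal-AI's global context after launch; a minimal sketch is shown below (the launch call itself is not visible in this diff).

```python
# minimal sketch: once colossalai has been launched with config.py,
# the fields above are available on the global context object
from colossalai.core import global_context as gpc

batch_size = gpc.config.BATCH_SIZE        # 512
num_classes = gpc.config.NUM_CLASSES      # 10
lr = gpc.config.LEARNING_RATE             # 3e-3
weight_decay = gpc.config.WEIGHT_DECAY    # 0.3
```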
5 changes: 3 additions & 2 deletions examples/tutorial/large_batch_optimizer/requirements.txt
@@ -1,2 +1,3 @@
colossalai >= 0.1.12
torch >= 1.8.1
colossalai
torch
titans
8 changes: 8 additions & 0 deletions examples/tutorial/large_batch_optimizer/test_ci.sh
@@ -0,0 +1,8 @@
#!/bin/bash
set -euxo pipefail

pip install -r requirements.txt

# run test
colossalai run --nproc_per_node 4 --master_port 29500 train.py --config config.py --optimizer lars
colossalai run --nproc_per_node 4 --master_port 29501 train.py --config config.py --optimizer lamb
76 changes: 18 additions & 58 deletions examples/tutorial/large_batch_optimizer/train.py
@@ -1,19 +1,13 @@
import os

import torch
from titans.dataloader.cifar10 import build_cifar
from titans.model.vit.vit import _create_vit_model
import torch.nn as nn
from torchvision.models import resnet18
from tqdm import tqdm

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import Lamb, Lars
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.utils import get_dataloader, is_using_pp


class DummyDataloader():
@@ -45,7 +39,10 @@ def __len__(self):
def main():
# initialize distributed setting
parser = colossalai.get_default_parser()
parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
parser.add_argument('--optimizer',
choices=['lars', 'lamb'],
help="Choose your large-batch optimizer",
required=True)
args = parser.parse_args()

# launch from torch
@@ -55,59 +52,22 @@ def main():
logger = get_dist_logger()
logger.info("initialized distributed environment", ranks=[0])

if hasattr(gpc.config, 'LOG_PATH'):
if gpc.get_global_rank() == 0:
log_path = gpc.config.LOG_PATH
if not os.path.exists(log_path):
os.mkdir(log_path)
logger.log_to_file(log_path)

use_pipeline = is_using_pp()

# create model
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
patch_size=gpc.config.PATCH_SIZE,
hidden_size=gpc.config.HIDDEN_SIZE,
depth=gpc.config.DEPTH,
num_heads=gpc.config.NUM_HEADS,
mlp_ratio=gpc.config.MLP_RATIO,
num_classes=10,
init_method='jax',
checkpoint=gpc.config.CHECKPOINT)

if use_pipeline:
pipelinable = PipelinableContext()
with pipelinable:
model = _create_vit_model(**model_kwargs)
pipelinable.to_layer_list()
pipelinable.policy = "uniform"
model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
else:
model = _create_vit_model(**model_kwargs)

# count number of parameters
total_numel = 0
for p in model.parameters():
total_numel += p.numel()
if not gpc.is_initialized(ParallelMode.PIPELINE):
pipeline_stage = 0
else:
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")

# create dataloaders
root = os.environ.get('DATA', '../data/')
if args.synthetic:
train_dataloader = DummyDataloader(length=30, batch_size=gpc.config.BATCH_SIZE)
test_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
else:
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True)
# create synthetic dataloaders
train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)

# build model
model = resnet18(num_classes=gpc.config.NUM_CLASSES)

# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1)
criterion = nn.CrossEntropyLoss()

# create optimizer
optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
if args.optimizer == "lars":
optim_cls = Lars
elif args.optimizer == "lamb":
optim_cls = Lamb
optimizer = optim_cls(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

# create lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
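The `train.py` diff is truncated at this point. As a rough, non-authoritative sketch of how the remaining pieces typically fit together under Colossal-AI's legacy engine API (the actual code in this PR is not shown, and exact signatures may differ across versions), the training loop could look roughly like this, continuing from the objects created above in `main()`:

```python
# hedged sketch only: in the legacy 0.x API, colossalai.initialize wraps the
# model, optimizer and criterion into an engine object; the real train.py may differ
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
)

for epoch in range(gpc.config.NUM_EPOCHS):
    engine.train()
    for img, label in train_dataloader:
        img, label = img.cuda(), label.cuda()
        engine.zero_grad()
        output = engine(img)                    # forward through the wrapped model
        loss = engine.criterion(output, label)  # wrapped loss function
        engine.backward(loss)
        engine.step()
    lr_scheduler.step()
```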