2 changes: 1 addition & 1 deletion colossalai/nn/optimizer/lamb.py
@@ -94,7 +94,7 @@ def step(self, closure=None):
# * math.sqrt(bias_correction2) / bias_correction1
step_size = group['lr']

-                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
+                weight_norm = p.data.pow(2).sum().sqrt()

adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
if group['weight_decay'] != 0:
14 changes: 14 additions & 0 deletions examples/vit-b16/README.md
@@ -0,0 +1,14 @@
# Overview

This is an example of training ViT-B/16 on ImageNet-1K with 8x A100 GPUs. For simplicity and speed, we apply only `Mixup` and skip `RandAug`. With the `LAMB` optimizer, we can scale the batch size to 32K with only a small accuracy loss.
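
With the settings in `vit-b16.py` (`BATCH_SIZE = 256` per GPU and `gradient_accumulation = 16`), the effective global batch size on 8 GPUs works out to 256 × 16 × 8 = 32768 ≈ 32K.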

# How to run
Using Slurm:
```shell
srun python train_dali.py --local_rank=$SLURM_PROCID --world_size=$SLURM_NPROCS --host=$HOST --port=29500 --config=vit-b16.py
```

# Results

![Loss Curve](./loss.jpeg)
![Accuracy](./acc.jpeg)
Binary file added examples/vit-b16/acc.jpeg
Empty file.
112 changes: 112 additions & 0 deletions examples/vit-b16/dataloader/imagenet_dali_dataloader.py
@@ -0,0 +1,112 @@
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import nvidia.dali.tfrecord as tfrec
import torch
import numpy as np


class DaliDataloader(DALIClassificationIterator):
def __init__(self,
tfrec_filenames,
tfrec_idx_filenames,
shard_id=0,
num_shards=1,
batch_size=128,
num_threads=4,
resize=256,
crop=224,
prefetch=2,
training=True,
gpu_aug=False,
cuda=True,
mixup_alpha=0.0):
self.mixup_alpha = mixup_alpha
self.training = training
pipe = Pipeline(batch_size=batch_size,
num_threads=num_threads,
device_id=torch.cuda.current_device() if cuda else None,
seed=1024)
with pipe:
inputs = fn.readers.tfrecord(
path=tfrec_filenames,
index_path=tfrec_idx_filenames,
random_shuffle=training,
shard_id=shard_id,
num_shards=num_shards,
initial_fill=10000,
read_ahead=True,
prefetch_queue_depth=prefetch,
name='Reader',
features={
'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
})
images = inputs["image/encoded"]

if training:
images = fn.decoders.image(images,
device='mixed' if gpu_aug else 'cpu',
output_type=types.RGB)
images = fn.random_resized_crop(images,
size=crop,
device='gpu' if gpu_aug else 'cpu')
flip_lr = fn.random.coin_flip(probability=0.5)
else:
# decode jpeg and resize
images = fn.decoders.image(images,
device='mixed' if gpu_aug else 'cpu',
output_type=types.RGB)
images = fn.resize(images,
device='gpu' if gpu_aug else 'cpu',
resize_x=resize,
resize_y=resize,
dtype=types.FLOAT,
interp_type=types.INTERP_TRIANGULAR)
flip_lr = False

# center crop and normalise
images = fn.crop_mirror_normalize(images,
dtype=types.FLOAT,
crop=(crop, crop),
mean=[127.5],
std=[127.5],
mirror=flip_lr)
label = inputs["image/class/label"] - 1 # 0-999
# LSG: element_extract will raise exception, let's flatten outside
# label = fn.element_extract(label, element_map=0) # Flatten
if cuda: # transfer data to gpu
pipe.set_outputs(images.gpu(), label.gpu())
else:
pipe.set_outputs(images, label)

pipe.build()
        last_batch_policy = LastBatchPolicy.DROP if training else LastBatchPolicy.PARTIAL
super().__init__(pipe, reader_name="Reader",
auto_reset=True,
last_batch_policy=last_batch_policy)

def __iter__(self):
        # reset the iterator after a completed epoch; on first initialization there is nothing to reset
if self._counter >= self._size or self._size < 0:
self.reset()
return self

def __next__(self):
data = super().__next__()
img, label = data[0]['data'], data[0]['label']
label = label.squeeze()
if self.mixup_alpha > 0.0:
if self.training:
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
idx = torch.randperm(img.size(0)).to(img.device)
img = lam * img + (1 - lam) * img[idx, :]
label_a, label_b = label, label[idx]
lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
label = (label_a, label_b, lam)
else:
label = (label, label, torch.ones(
1, device=img.device, dtype=img.dtype))
return (img,), label
return (img,), (label,)
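
For reference, a minimal single-GPU usage sketch of `DaliDataloader`; the TFRecord and index-file paths below are hypothetical and mirror the layout that `train_dali.py` derives from `gpc.config.dali.root`:

```python
import glob

# hypothetical paths; train_dali.py builds them from gpc.config.dali.root
train_loader = DaliDataloader(
    sorted(glob.glob('./dataset/ILSVRC2012_1k/train/*')),
    sorted(glob.glob('./dataset/ILSVRC2012_1k/idx_files/train/*')),
    batch_size=128,
    training=True,
    gpu_aug=True,
    cuda=True,
    mixup_alpha=0.2,
)

for (img,), label in train_loader:
    # with mixup_alpha > 0 and training=True, label is (label_a, label_b, lam)
    pass
```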
15 changes: 15 additions & 0 deletions examples/vit-b16/hooks.py
@@ -0,0 +1,15 @@
from colossalai.registry import HOOKS
from colossalai.trainer import BaseHook
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode


@HOOKS.register_module
class TotalBatchsizeHook(BaseHook):
def __init__(self, trainer, priority: int = 2) -> None:
super().__init__(trainer, priority)

def before_train(self):
total_batch_size = gpc.config.BATCH_SIZE * \
gpc.config.engine.gradient_accumulation * gpc.get_world_size(ParallelMode.DATA)
self.logger.info(f'Total batch size = {total_batch_size}', ranks=[0])
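
Once registered via `@HOOKS.register_module`, the hook can be referenced by its class name from the training config, as `vit-b16.py` does; a minimal sketch:

```python
# excerpt from a config such as vit-b16.py: each entry is looked up
# in the HOOKS registry by the string in 'type'
hooks = [
    dict(type='TotalBatchsizeHook'),
]
```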
Binary file added examples/vit-b16/loss.jpeg
12 changes: 12 additions & 0 deletions examples/vit-b16/mixup.py
@@ -0,0 +1,12 @@
import torch.nn as nn
from colossalai.registry import LOSSES

@LOSSES.register_module
class MixupLoss(nn.Module):
def __init__(self, loss_fn_cls):
super().__init__()
self.loss_fn = loss_fn_cls()

def forward(self, inputs, *args):
targets_a, targets_b, lam = args
return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b)
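
A minimal sketch of how `MixupLoss` consumes the `(label_a, label_b, lam)` triple that `DaliDataloader.__next__` emits when `mixup_alpha > 0` (shapes are illustrative):

```python
import torch
from torch.nn import CrossEntropyLoss
from mixup import MixupLoss

criterion = MixupLoss(CrossEntropyLoss)
logits = torch.randn(4, 1000)           # model outputs for a batch of 4
label_a = torch.randint(0, 1000, (4,))  # original labels
label_b = torch.randint(0, 1000, (4,))  # labels of the mixed-in samples
lam = torch.tensor(0.7)                 # mixing coefficient drawn from Beta(alpha, alpha)

loss = criterion(logits, label_a, label_b, lam)
```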
70 changes: 70 additions & 0 deletions examples/vit-b16/train_dali.py
@@ -0,0 +1,70 @@
import glob
import os
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import set_global_multitimer_status
from dataloader.imagenet_dali_dataloader import DaliDataloader


def build_dali_train():
root = gpc.config.dali.root
train_pat = os.path.join(root, 'train/*')
train_idx_pat = os.path.join(root, 'idx_files/train/*')
return DaliDataloader(
sorted(glob.glob(train_pat)),
sorted(glob.glob(train_idx_pat)),
batch_size=gpc.config.BATCH_SIZE,
shard_id=gpc.get_local_rank(ParallelMode.DATA),
num_shards=gpc.get_world_size(ParallelMode.DATA),
training=True,
gpu_aug=gpc.config.dali.gpu_aug,
cuda=True,
mixup_alpha=gpc.config.dali.mixup_alpha
)


def build_dali_test():
root = gpc.config.dali.root
val_pat = os.path.join(root, 'validation/*')
val_idx_pat = os.path.join(root, 'idx_files/validation/*')
return DaliDataloader(
sorted(glob.glob(val_pat)),
sorted(glob.glob(val_idx_pat)),
batch_size=gpc.config.BATCH_SIZE,
shard_id=gpc.get_local_rank(ParallelMode.DATA),
num_shards=gpc.get_world_size(ParallelMode.DATA),
training=False,
# gpu_aug=gpc.config.dali.gpu_aug,
gpu_aug=False,
cuda=True,
mixup_alpha=gpc.config.dali.mixup_alpha
)


def main():
engine, train_dataloader, test_dataloader = colossalai.initialize(
train_dataloader=build_dali_train,
test_dataloader=build_dali_test
)
logger = get_global_dist_logger()
set_global_multitimer_status(True)
timer = colossalai.utils.get_global_multitimer()
trainer = Trainer(engine=engine,
verbose=True,
timer=timer)

trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
epochs=gpc.config.NUM_EPOCHS,
hooks_cfg=gpc.config.hooks,
display_progress=True,
test_interval=1
)


if __name__ == '__main__':
main()
78 changes: 78 additions & 0 deletions examples/vit-b16/vit-b16.py
@@ -0,0 +1,78 @@
from colossalai.engine import AMP_TYPE
from torch.nn import CrossEntropyLoss
from mixup import MixupLoss
from hooks import TotalBatchsizeHook
from colossalai.registry import MODELS
from timm.models import vit_base_patch16_224

MODELS.register_module(vit_base_patch16_224)

LOG_NAME = 'vit-b16-1k-32k-mixup-light2'
# ViT Base
BATCH_SIZE = 256
DROP_RATE = 0.1
NUM_EPOCHS = 300

parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None),
)

optimizer = dict(
type='Lamb',
lr=1.8e-2,
weight_decay=0.1,
)


loss = dict(
type='MixupLoss',
loss_fn_cls=CrossEntropyLoss
)

model = dict(
type='vit_base_patch16_224',
drop_rate=DROP_RATE,
)

hooks = [
dict(type='LogMetricByEpochHook'),
dict(type='AccuracyHook'),
dict(type='LossHook'),
dict(type='TotalBatchsizeHook'),
dict(type='TensorboardHook', log_dir=f'./tb_logs/{LOG_NAME}'),
dict(type='SaveCheckpointHook', interval=1,
checkpoint_dir=f'./ckpt/{LOG_NAME}'),
# dict(type='LoadCheckpointHook', epoch=10,
# checkpoint_dir=f'./ckpt/{LOG_NAME}'),
dict(
type='LRSchedulerHook',
by_epoch=True,
lr_scheduler_cfg=dict(
type='LinearWarmupLR',
warmup_steps=150
)
),
]

fp16 = dict(
mode=AMP_TYPE.TORCH,
)


logging = dict(
root_path=f"./logs/{LOG_NAME}"
)

dali = dict(
root='./dataset/ILSVRC2012_1k',
gpu_aug=True,
mixup_alpha=0.2
)

engine = dict(
schedule=None,
gradient_handlers=None,
gradient_accumulation=16,
gradient_clipping=1.0,
)