Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions colossalai/engine/_base_engine.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-


import torch
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import is_using_ddp, ConditionalContext, is_using_pp
from colossalai.utils.cuda import get_current_device
from .schedule import BaseSchedule


Expand Down Expand Up @@ -71,11 +73,10 @@ def __init__(self,
"Training with zero is detected, ZeROGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
ParallelMode.DATA) > 1:
elif is_using_ddp() and is_using_pp():
gradient_handlers = [dict(type='DataParallelGradientHandler')]
self._logger.info(
"Data parallel training is detected, DataParallelGradientHandler is automatically "
"Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])

Expand Down Expand Up @@ -147,17 +148,33 @@ def step(self,

# differentiate training and eval with grad accum
if self.training:
for i in range(self._grad_accum_size):
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=False,
grad_accum_size=self._grad_accum_size,
return_loss=return_loss)

if i == self._grad_accum_size - 1:
# all reduce gradients
self.handle_gradient()
self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
outputs = []
labels = []
loss = torch.zeros(1, device=get_current_device())
with ConditionalContext(self._model.no_sync(), enable=is_using_ddp() and not is_using_pp()):
for i in range(self._grad_accum_size - 1):
output, label, loss_ = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=False,
grad_accum_size=self._grad_accum_size,
return_loss=return_loss)
outputs.append(output)
labels.append(label)
loss.add_(loss_)
output, label, loss_ = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=False,
grad_accum_size=self._grad_accum_size,
return_loss=return_loss)
outputs.append(output)
labels.append(label)
loss.add_(loss_)
output = self._accum_outputs(outputs)
label = self._accum_outputs(labels)
# all reduce gradients
self.handle_gradient()
self._schedule.optimizer_step(
self._model, self._optimizer, self._grad_clip)
else:
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
Expand All @@ -174,3 +191,7 @@ def step(self,
break

return output, label, loss

@staticmethod
def _accum_outputs(tensor_tuples):
return tuple([torch.cat(x) for x in zip(*tensor_tuples)])
10 changes: 7 additions & 3 deletions colossalai/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import torch
from torch.utils.data import DataLoader

from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
Expand All @@ -22,7 +22,7 @@
build_optimizer_wrapper, build_schedule)
from .context import Config, ParallelMode
from .core import global_context as gpc
from .utils import get_current_device, sync_model_param_in_dp
from .utils import get_current_device, sync_model_param_in_dp, is_using_ddp, is_using_pp


def parse_args():
Expand Down Expand Up @@ -276,6 +276,10 @@ def initialize(config: Union[str, dict] = None,
model = model.half()
logger.info("Model is cast to fp16", ranks=[0])

if is_using_ddp() and not is_using_pp():
model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
logger.info(
'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0])
# training data
if callable(train_dataloader):
logger.info(
Expand All @@ -288,7 +292,7 @@ def initialize(config: Union[str, dict] = None,
logger.info('Train dataset is ready.', ranks=[0])

train_dataloader = get_dataloader(train_dataset,
gpc.config.get('seed', 1024),
gpc.config.get('seed', 42),
True,
**gpc.config.train_data.dataloader,
)
Expand Down
2 changes: 1 addition & 1 deletion colossalai/nn/optimizer/lamb.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def step(self, closure=None):
# * math.sqrt(bias_correction2) / bias_correction1
step_size = group['lr']

weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
weight_norm = p.data.pow(2).sum().sqrt()

adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
if group['weight_decay'] != 0:
Expand Down
8 changes: 6 additions & 2 deletions colossalai/trainer/hooks/_log_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,23 @@ def __init__(self,
trainer: Trainer,
interval: int = 1,
priority: int = 10,
log_eval: bool = True
log_eval: bool = True,
ignore_num_train_steps: int = 0
) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)
set_global_multitimer_status(True)
self._global_timer = get_global_multitimer()
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
self.ignore_num_train_steps = ignore_num_train_steps

def _get_message(self):
msg = []
for timer_name, timer in self._global_timer:
last_elapsed_time = timer.get_elapsed_time()
if timer.has_history:
if timer_name == 'train-step':
timer._history = timer._history[self.ignore_num_train_steps:]
history_mean = timer.get_history_mean()
history_sum = timer.get_history_sum()
msg.append(
Expand All @@ -201,7 +205,7 @@ def after_train_epoch(self):
if self._is_epoch_to_log() and self._is_rank_to_log:
msg = self._get_message()
self.logger.info(
f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}, num steps per epoch={self.trainer.steps_per_epoch}')

def after_test_epoch(self):
"""Writes log after finishing a testing epoch.
Expand Down
5 changes: 3 additions & 2 deletions colossalai/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .activation_checkpoint import checkpoint
from .common import print_rank_0, sync_model_param_in_dp, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
from .common import print_rank_0, sync_model_param_in_dp, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage, is_using_ddp, is_using_pp, ConditionalContext
from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda
from .memory import report_memory_usage
from .timer import MultiTimer, Timer
Expand All @@ -18,5 +18,6 @@ def set_global_multitimer_status(mode: bool):
__all__ = ['checkpoint', 'print_rank_0', 'sync_model_param_in_dp', 'get_current_device',
'synchronize', 'empty_cache', 'set_to_cuda', 'report_memory_usage', 'Timer', 'MultiTimer',
'get_global_multitimer', 'set_global_multitimer_status',
'is_dp_rank_0', 'is_tp_rank_0', 'is_no_pp_or_last_stage'
'is_dp_rank_0', 'is_tp_rank_0', 'is_no_pp_or_last_stage',
'is_using_ddp', 'ConditionalContext', 'is_using_pp'
]
28 changes: 24 additions & 4 deletions colossalai/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*-

import torch.distributed as dist

from contextlib import contextmanager
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc

Expand All @@ -26,17 +26,37 @@ def sync_model_param_in_dp(model):

:param model: A pyTorch nn.model on whose parameters you check the consistency
'''

if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
for param in model.parameters():
ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
dist.broadcast(
param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))


def is_dp_rank_0():
    """Return True when the data-parallel group is uninitialized or this
    process is the first rank of that group."""
    if not gpc.is_initialized(ParallelMode.DATA):
        return True
    return gpc.is_first_rank(ParallelMode.DATA)


def is_tp_rank_0():
    """Return True when the tensor-parallel group is uninitialized or this
    process is the first rank of that group."""
    if not gpc.is_initialized(ParallelMode.TENSOR):
        return True
    return gpc.is_first_rank(ParallelMode.TENSOR)


def is_no_pp_or_last_stage():
    """Return True when pipeline parallelism is not initialized, or when this
    process is the last pipeline stage."""
    if not gpc.is_initialized(ParallelMode.PIPELINE):
        return True
    return gpc.is_last_rank(ParallelMode.PIPELINE)


def is_using_ddp():
    """Return True when a data-parallel group exists and spans more than one
    process (i.e. distributed data parallelism is actually in effect)."""
    if not gpc.is_initialized(ParallelMode.DATA):
        return False
    return gpc.get_world_size(ParallelMode.DATA) > 1


def is_using_pp():
    """Return True when a pipeline-parallel group exists and spans more than
    one process (i.e. pipeline parallelism is actually in effect)."""
    if not gpc.is_initialized(ParallelMode.PIPELINE):
        return False
    return gpc.get_world_size(ParallelMode.PIPELINE) > 1


@contextmanager
def ConditionalContext(context_manager, enable=True):
    """Conditionally enter a context manager.

    :param context_manager: the context manager to enter when enabled
    :param enable: when False, yield without entering ``context_manager``
        (a no-op context)
    """
    # Guard clause: disabled case is a plain pass-through.
    if not enable:
        yield
        return
    with context_manager:
        yield
2 changes: 1 addition & 1 deletion configs/vit/vit_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@

parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2d'),
tensor=dict(size=1, mode='2d'),
)

# for fp16 training
Expand Down
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
# The full version, including alpha/beta/rc tags
release = '0.0.1'

if 'SPHINX_LANG' in os.environ:
root_doc = f'index_{os.environ["SPHINX_LANG"]}'
# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
Expand Down
26 changes: 13 additions & 13 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,27 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.

夸父AI系统(Colossal-AI)开发文档
Colossal-AI documentation
======================================
.. toctree::
:maxdepth: 1
:caption: 快速上手指南
:caption: GETTING STARTED

installation_zh.md
run_demo_zh.md
installation.md
run_demo.md


.. toctree::
:maxdepth: 1
:caption: 个性化您的训练

parallelization_zh.md
model_zh.md
trainer_engine_zh.md
amp_zh.md
zero_zh.md
add_your_parallel_zh.md
config_zh.md
:caption: CUSTOMIZE YOUR TRAINING

parallelization.md
model.md
trainer_engine.md
amp.md
zero.md
add_your_parallel.md
config.md



Expand Down
26 changes: 13 additions & 13 deletions docs/index_en.rst → docs/index_zh.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,27 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.

Colossal-AI documentation
夸父AI系统(Colossal-AI)开发文档
======================================
.. toctree::
:maxdepth: 1
:caption: GETTING STARTED
:caption: 快速上手指南

installation.md
run_demo.md
installation_zh.md
run_demo_zh.md


.. toctree::
:maxdepth: 1
:caption: CUSTOMIZE YOUR TRAINING

parallelization.md
model.md
trainer_engine.md
amp.md
zero.md
add_your_parallel.md
config.md
:caption: 个性化您的训练

parallelization_zh.md
model_zh.md
trainer_engine_zh.md
amp_zh.md
zero_zh.md
add_your_parallel_zh.md
config_zh.md



Expand Down
40 changes: 40 additions & 0 deletions examples/vit-b16/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Overview

Here is an example of training ViT-B/16 on Imagenet-1K with batch size 32K.
We use 8x NVIDIA A100 GPU in this example.

# How to run
Using [Slurm](https://slurm.schedmd.com/documentation.html):
```shell
srun python train_dali.py --local_rank=$SLURM_PROCID --world_size=$SLURM_NPROCS --host=$HOST --port=29500 --config=vit-b16.py
```

# Results

![Loss Curve](./loss.jpeg)
![Accuracy](./acc.jpeg)

# Details
`vit-b16.py`

It is a [config file](https://colossalai.org/config.html), which is used by ColossalAI to define all kinds of training arguments, such as the model, dataset, and training method (optimizer, lr_scheduler, epoch, etc.). You can access config content by `gpc.config`.

In this example, we train the ViT-Base patch 16 model for 300 epochs on ImageNet-1K. The batch size is set to 32K through data parallelism (4K on each GPU from 16x gradient accumulation with batch size 256). Since the batch size is much larger than in common usage, which leads to convergence difficulties, we use the
large-batch optimizer [LAMB](https://arxiv.org/abs/1904.00962), which lets us scale the batch size to 32K with only a small accuracy loss. The learning rate and weight decay of the optimizer are set to 1.8e-2 and 0.1, respectively. We use a linear warmup learning rate scheduler and warm up for 150 epochs.
We introduce FP16 mixed precision to accelerate training and use gradient clipping to help convergence.
For simplicity and speed, we didn't apply `RandAug` and just used [Mixup](https://arxiv.org/abs/1710.09412) in data augmentation.

If you have enough computing resources, you can expand this example conveniently with data parallel on a very large scale without gradient accumulation, and finish the training process even within one hour.


`imagenet_dali_dataloader.py`
To accelerate the training process, we use [DALI](https://github.com/NVIDIA/DALI) as the data loader. Note that it requires the dataset in TFRecord format, which avoids reading raw images and thereby degrading file-system efficiency.

`train_dali.py`
We build the DALI data loader and the training process using Colossal-AI here.

`mixup.py`
Since we used Mixup, we define mixup loss in this file.

`hooks.py`
We also define useful hooks that log information to help with debugging.
Binary file added examples/vit-b16/acc.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
Loading