diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py index a99aa91e73c3..2b9f6fd41002 100644 --- a/colossalai/engine/_base_engine.py +++ b/colossalai/engine/_base_engine.py @@ -1,16 +1,18 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- + +import torch from torch.nn import Module from torch.nn.modules.loss import _Loss from torch.optim import Optimizer from colossalai.builder import build_gradient_handler -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc from colossalai.logging import get_global_dist_logger from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3) +from colossalai.utils import is_using_ddp, ConditionalContext, is_using_pp +from colossalai.utils.cuda import get_current_device from .schedule import BaseSchedule @@ -71,11 +73,10 @@ def __init__(self, "Training with zero is detected, ZeROGradientHandler is automatically " "added even though not specified in the configuration", ranks=[0]) - elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size( - ParallelMode.DATA) > 1: + elif is_using_ddp() and is_using_pp(): gradient_handlers = [dict(type='DataParallelGradientHandler')] self._logger.info( - "Data parallel training is detected, DataParallelGradientHandler is automatically " + "Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically " "added even though not specified in the configuration", ranks=[0]) @@ -147,17 +148,33 @@ def step(self, # differentiate training and eval with grad accum if self.training: - for i in range(self._grad_accum_size): - output, label, loss = self._schedule.forward_backward_step( - data_iter, self._model, self._criterion, self._optimizer, - forward_only=False, - grad_accum_size=self._grad_accum_size, - return_loss=return_loss) - - if i == self._grad_accum_size - 1: - # all reduce gradients - self.handle_gradient() - self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip) + outputs = [] + labels = [] + loss = torch.zeros(1, device=get_current_device()) + with ConditionalContext(self._model.no_sync(), enable=is_using_ddp() and not is_using_pp()): + for i in range(self._grad_accum_size - 1): + output, label, loss_ = self._schedule.forward_backward_step( + data_iter, self._model, self._criterion, self._optimizer, + forward_only=False, + grad_accum_size=self._grad_accum_size, + return_loss=return_loss) + outputs.append(output) + labels.append(label) + loss.add_(loss_) + output, label, loss_ = self._schedule.forward_backward_step( + data_iter, self._model, self._criterion, self._optimizer, + forward_only=False, + grad_accum_size=self._grad_accum_size, + return_loss=return_loss) + outputs.append(output) + labels.append(label) + loss.add_(loss_) + output = self._accum_outputs(outputs) + label = self._accum_outputs(labels) + # all reduce gradients + self.handle_gradient() + self._schedule.optimizer_step( + self._model, self._optimizer, self._grad_clip) else: output, label, loss = self._schedule.forward_backward_step( data_iter, self._model, self._criterion, self._optimizer, @@ -174,3 +191,7 @@ def step(self, break return output, label, loss + + @staticmethod + def _accum_outputs(tensor_tuples): + return tuple([torch.cat(x) for x in zip(*tensor_tuples)]) diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 6806d86eb61c..3c94c1cbfed3 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -11,7 +11,7 @@ import numpy as np import torch from torch.utils.data import DataLoader - +from torch.nn.parallel import DistributedDataParallel as DDP from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule from colossalai.engine import Engine from colossalai.logging import get_global_dist_logger, init_global_dist_logger @@ -22,7 +22,7 @@ build_optimizer_wrapper, build_schedule) from .context import Config, ParallelMode from .core import global_context as gpc -from .utils import get_current_device, sync_model_param_in_dp +from .utils import get_current_device, sync_model_param_in_dp, is_using_ddp, is_using_pp def parse_args(): @@ -276,6 +276,10 @@ def initialize(config: Union[str, dict] = None, model = model.half() logger.info("Model is cast to fp16", ranks=[0]) + if is_using_ddp() and not is_using_pp(): + model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA)) + logger.info( + 'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0]) # training data if callable(train_dataloader): logger.info( @@ -288,7 +292,7 @@ def initialize(config: Union[str, dict] = None, logger.info('Train dataset is ready.', ranks=[0]) train_dataloader = get_dataloader(train_dataset, - gpc.config.get('seed', 1024), + gpc.config.get('seed', 42), True, **gpc.config.train_data.dataloader, ) diff --git a/colossalai/nn/optimizer/lamb.py b/colossalai/nn/optimizer/lamb.py index 68531e92a249..f7248bd68fe7 100644 --- a/colossalai/nn/optimizer/lamb.py +++ b/colossalai/nn/optimizer/lamb.py @@ -94,7 +94,7 @@ def step(self, closure=None): # * math.sqrt(bias_correction2) / bias_correction1 step_size = group['lr'] - weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) + weight_norm = p.data.pow(2).sum().sqrt() adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) if group['weight_decay'] != 0: diff --git a/colossalai/trainer/hooks/_log_hook.py b/colossalai/trainer/hooks/_log_hook.py index 3c3fdfc43ef8..d4d84dff76f2 100644 --- a/colossalai/trainer/hooks/_log_hook.py +++ b/colossalai/trainer/hooks/_log_hook.py @@ -170,19 +170,23 @@ def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 10, - log_eval: bool = True + log_eval: bool = True, + ignore_num_train_steps: int = 0 ) -> None: super().__init__(trainer=trainer, interval=interval, priority=priority) set_global_multitimer_status(True) self._global_timer = get_global_multitimer() self._log_eval = log_eval self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() + self.ignore_num_train_steps = ignore_num_train_steps def _get_message(self): msg = [] for timer_name, timer in self._global_timer: last_elapsed_time = timer.get_elapsed_time() if timer.has_history: + if timer_name == 'train-step': + timer._history = timer._history[self.ignore_num_train_steps:] history_mean = timer.get_history_mean() history_sum = timer.get_history_sum() msg.append( @@ -201,7 +205,7 @@ def after_train_epoch(self): if self._is_epoch_to_log() and self._is_rank_to_log: msg = self._get_message() self.logger.info( - f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}') + f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}, num steps per epoch={self.trainer.steps_per_epoch}') def after_test_epoch(self): """Writes log after finishing a testing epoch. diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index f7ef2259bed0..64aafab740e9 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -1,5 +1,5 @@ from .activation_checkpoint import checkpoint -from .common import print_rank_0, sync_model_param_in_dp, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage +from .common import print_rank_0, sync_model_param_in_dp, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage, is_using_ddp, is_using_pp, ConditionalContext from .cuda import get_current_device, synchronize, empty_cache, set_to_cuda from .memory import report_memory_usage from .timer import MultiTimer, Timer @@ -18,5 +18,6 @@ def set_global_multitimer_status(mode: bool): __all__ = ['checkpoint', 'print_rank_0', 'sync_model_param_in_dp', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda', 'report_memory_usage', 'Timer', 'MultiTimer', 'get_global_multitimer', 'set_global_multitimer_status', - 'is_dp_rank_0', 'is_tp_rank_0', 'is_no_pp_or_last_stage' + 'is_dp_rank_0', 'is_tp_rank_0', 'is_no_pp_or_last_stage', + 'is_using_ddp', 'ConditionalContext', 'is_using_pp' ] diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index d8c6663ba626..ce6432166328 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -2,7 +2,7 @@ # -*- encoding: utf-8 -*- import torch.distributed as dist - +from contextlib import contextmanager from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc @@ -26,17 +26,37 @@ def sync_model_param_in_dp(model): :param model: A pyTorch nn.model on whose parameters you check the consistency ''' - if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1: for param in model.parameters(): ranks = gpc.get_ranks_in_group(ParallelMode.DATA) - dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA)) + dist.broadcast( + param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA)) + def is_dp_rank_0(): return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA) + def is_tp_rank_0(): return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR) + def is_no_pp_or_last_stage(): - return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE) \ No newline at end of file + return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE) + + +def is_using_ddp(): + return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1 + + +def is_using_pp(): + return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1 + + +@contextmanager +def ConditionalContext(context_manager, enable=True): + if enable: + with context_manager: + yield + else: + yield diff --git a/configs/vit/vit_2d.py b/configs/vit/vit_2d.py index b771b583e9d9..23ddc8d6cad8 100644 --- a/configs/vit/vit_2d.py +++ b/configs/vit/vit_2d.py @@ -144,7 +144,7 @@ parallel = dict( pipeline=dict(size=1), - tensor=dict(size=4, mode='2d'), + tensor=dict(size=1, mode='2d'), ) # for fp16 training diff --git a/docs/conf.py b/docs/conf.py index b0a57bdbc08b..695477e35fbe 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -24,6 +24,8 @@ # The full version, including alpha/beta/rc tags release = '0.0.1' +if 'SPHINX_LANG' in os.environ: + root_doc = f'index_{os.environ["SPHINX_LANG"]}' # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be diff --git a/docs/index.rst b/docs/index.rst index f9a6ce444a79..16141b5ead8e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,27 +3,27 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -夸父AI系统(Colossal-AI)开发文档 +Colossal-AI documentation ====================================== .. toctree:: :maxdepth: 1 - :caption: 快速上手指南 + :caption: GETTING STARTED - installation_zh.md - run_demo_zh.md + installation.md + run_demo.md .. toctree:: :maxdepth: 1 - :caption: 个性化您的训练 - - parallelization_zh.md - model_zh.md - trainer_engine_zh.md - amp_zh.md - zero_zh.md - add_your_parallel_zh.md - config_zh.md + :caption: CUSTOMIZE YOUR TRAINING + + parallelization.md + model.md + trainer_engine.md + amp.md + zero.md + add_your_parallel.md + config.md diff --git a/docs/index_en.rst b/docs/index_zh.rst similarity index 62% rename from docs/index_en.rst rename to docs/index_zh.rst index 16141b5ead8e..f9a6ce444a79 100644 --- a/docs/index_en.rst +++ b/docs/index_zh.rst @@ -3,27 +3,27 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Colossal-AI documentation +夸父AI系统(Colossal-AI)开发文档 ====================================== .. toctree:: :maxdepth: 1 - :caption: GETTING STARTED + :caption: 快速上手指南 - installation.md - run_demo.md + installation_zh.md + run_demo_zh.md .. toctree:: :maxdepth: 1 - :caption: CUSTOMIZE YOUR TRAINING - - parallelization.md - model.md - trainer_engine.md - amp.md - zero.md - add_your_parallel.md - config.md + :caption: 个性化您的训练 + + parallelization_zh.md + model_zh.md + trainer_engine_zh.md + amp_zh.md + zero_zh.md + add_your_parallel_zh.md + config_zh.md diff --git a/examples/vit-b16/README.md b/examples/vit-b16/README.md new file mode 100644 index 000000000000..c28c7ed4477b --- /dev/null +++ b/examples/vit-b16/README.md @@ -0,0 +1,40 @@ +# Overview + +Here is an example of training ViT-B/16 on Imagenet-1K with batch size 32K. +We use 8x NVIDIA A100 GPU in this example. + +# How to run +Using [Slurm](https://slurm.schedmd.com/documentation.html): +```shell +srun python train_dali.py --local_rank=$SLURM_PROCID --world_size=$SLURM_NPROCS --host=$HOST --port=29500 --config=vit-b16.py +``` + +# Results + +![Loss Curve](./loss.jpeg) +![Accuracy](./acc.jpeg) + +# Details +`vit-b16.py` + +It is a [config file](https://colossalai.org/config.html), which is used by ColossalAI to define all kinds of training arguments, such as the model, dataset, and training method (optimizer, lr_scheduler, epoch, etc.). You can access config content by `gpc.config`. + +In this example, we train the ViT-Base patch 16 model 300 epochs on ImageNet-1K. The batch size is set to 32K through data parallel (4K on each GPU from 16x gradient accumulation with batch size 256). Since the batch size is very large than common usage, leading to convergence difficulties, we use a +large batch optimizer [LAMB](https://arxiv.org/abs/1904.00962), and we can scale the batch size to 32K with a little accuracy loss. The learning rate and weight decay of the optimizer are set to 1.8e-2 and 0.1, respectively. We use a linear warmup learning rate scheduler and warmup 150 epochs. +We introduce FP16 mixed precision to accelerate training and use gradient clipping to help convergence. +For simplicity and speed, we didn't apply `RandAug` and just used [Mixup](https://arxiv.org/abs/1710.09412) in data augmentation. + +If you have enough computing resources, you can expand this example conveniently with data parallel on a very large scale without gradient accumulation, and finish the training process even within one hour. + + +`imagenet_dali_dataloader.py` +To accelerate the training process, we use [DALI](https://github.com/NVIDIA/DALI) as data loader. Note that it requires the dataset in TFRecord format, avoiding read raw images which reduces efficiency of the file system. + +`train_dali.py` +We build the DALI data loader and train process using Colossal-AI here. + +`mixup.py` +Since we used Mixup, we define mixup loss in this file. + +`hooks.py` +We also define useful hooks to log information help debugging. \ No newline at end of file diff --git a/examples/vit-b16/acc.jpeg b/examples/vit-b16/acc.jpeg new file mode 100755 index 000000000000..43f67fd39167 Binary files /dev/null and b/examples/vit-b16/acc.jpeg differ diff --git a/examples/vit-b16/dataloader/__init__.py b/examples/vit-b16/dataloader/__init__.py new file mode 100755 index 000000000000..e69de29bb2d1 diff --git a/examples/vit-b16/dataloader/imagenet_dali_dataloader.py b/examples/vit-b16/dataloader/imagenet_dali_dataloader.py new file mode 100755 index 000000000000..a39d73e26c36 --- /dev/null +++ b/examples/vit-b16/dataloader/imagenet_dali_dataloader.py @@ -0,0 +1,112 @@ +from nvidia.dali.pipeline import Pipeline +from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy +import nvidia.dali.fn as fn +import nvidia.dali.types as types +import nvidia.dali.tfrecord as tfrec +import torch +import numpy as np + + +class DaliDataloader(DALIClassificationIterator): + def __init__(self, + tfrec_filenames, + tfrec_idx_filenames, + shard_id=0, + num_shards=1, + batch_size=128, + num_threads=4, + resize=256, + crop=224, + prefetch=2, + training=True, + gpu_aug=False, + cuda=True, + mixup_alpha=0.0): + self.mixup_alpha = mixup_alpha + self.training = training + pipe = Pipeline(batch_size=batch_size, + num_threads=num_threads, + device_id=torch.cuda.current_device() if cuda else None, + seed=1024) + with pipe: + inputs = fn.readers.tfrecord( + path=tfrec_filenames, + index_path=tfrec_idx_filenames, + random_shuffle=training, + shard_id=shard_id, + num_shards=num_shards, + initial_fill=10000, + read_ahead=True, + prefetch_queue_depth=prefetch, + name='Reader', + features={ + 'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""), + 'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1), + }) + images = inputs["image/encoded"] + + if training: + images = fn.decoders.image(images, + device='mixed' if gpu_aug else 'cpu', + output_type=types.RGB) + images = fn.random_resized_crop(images, + size=crop, + device='gpu' if gpu_aug else 'cpu') + flip_lr = fn.random.coin_flip(probability=0.5) + else: + # decode jpeg and resize + images = fn.decoders.image(images, + device='mixed' if gpu_aug else 'cpu', + output_type=types.RGB) + images = fn.resize(images, + device='gpu' if gpu_aug else 'cpu', + resize_x=resize, + resize_y=resize, + dtype=types.FLOAT, + interp_type=types.INTERP_TRIANGULAR) + flip_lr = False + + # center crop and normalise + images = fn.crop_mirror_normalize(images, + dtype=types.FLOAT, + crop=(crop, crop), + mean=[127.5], + std=[127.5], + mirror=flip_lr) + label = inputs["image/class/label"] - 1 # 0-999 + # LSG: element_extract will raise exception, let's flatten outside + # label = fn.element_extract(label, element_map=0) # Flatten + if cuda: # transfer data to gpu + pipe.set_outputs(images.gpu(), label.gpu()) + else: + pipe.set_outputs(images, label) + + pipe.build() + last_batch_policy = 'DROP' if training else 'PARTIAL' + super().__init__(pipe, reader_name="Reader", + auto_reset=True, + last_batch_policy=last_batch_policy) + + def __iter__(self): + # if not reset (after an epoch), reset; if just initialize, ignore + if self._counter >= self._size or self._size < 0: + self.reset() + return self + + def __next__(self): + data = super().__next__() + img, label = data[0]['data'], data[0]['label'] + label = label.squeeze() + if self.mixup_alpha > 0.0: + if self.training: + lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) + idx = torch.randperm(img.size(0)).to(img.device) + img = lam * img + (1 - lam) * img[idx, :] + label_a, label_b = label, label[idx] + lam = torch.tensor([lam], device=img.device, dtype=img.dtype) + label = (label_a, label_b, lam) + else: + label = (label, label, torch.ones( + 1, device=img.device, dtype=img.dtype)) + return (img,), label + return (img,), (label,) diff --git a/examples/vit-b16/hooks.py b/examples/vit-b16/hooks.py new file mode 100644 index 000000000000..b6c306ed7184 --- /dev/null +++ b/examples/vit-b16/hooks.py @@ -0,0 +1,15 @@ +from colossalai.registry import HOOKS +from colossalai.trainer import BaseHook +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode + + +@HOOKS.register_module +class TotalBatchsizeHook(BaseHook): + def __init__(self, trainer, priority: int = 2) -> None: + super().__init__(trainer, priority) + + def before_train(self): + total_batch_size = gpc.config.BATCH_SIZE * \ + gpc.config.engine.gradient_accumulation * gpc.get_world_size(ParallelMode.DATA) + self.logger.info(f'Total batch size = {total_batch_size}', ranks=[0]) diff --git a/examples/vit-b16/loss.jpeg b/examples/vit-b16/loss.jpeg new file mode 100755 index 000000000000..a16c333cc8e9 Binary files /dev/null and b/examples/vit-b16/loss.jpeg differ diff --git a/examples/vit-b16/mixup.py b/examples/vit-b16/mixup.py new file mode 100644 index 000000000000..822bc8659df0 --- /dev/null +++ b/examples/vit-b16/mixup.py @@ -0,0 +1,12 @@ +import torch.nn as nn +from colossalai.registry import LOSSES + +@LOSSES.register_module +class MixupLoss(nn.Module): + def __init__(self, loss_fn_cls): + super().__init__() + self.loss_fn = loss_fn_cls() + + def forward(self, inputs, *args): + targets_a, targets_b, lam = args + return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b) diff --git a/examples/vit-b16/train_dali.py b/examples/vit-b16/train_dali.py new file mode 100644 index 000000000000..fed39c3cc3f8 --- /dev/null +++ b/examples/vit-b16/train_dali.py @@ -0,0 +1,70 @@ +import glob +import os +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_global_dist_logger +from colossalai.trainer import Trainer +from colossalai.utils import set_global_multitimer_status +from dataloader.imagenet_dali_dataloader import DaliDataloader + + +def build_dali_train(): + root = gpc.config.dali.root + train_pat = os.path.join(root, 'train/*') + train_idx_pat = os.path.join(root, 'idx_files/train/*') + return DaliDataloader( + sorted(glob.glob(train_pat)), + sorted(glob.glob(train_idx_pat)), + batch_size=gpc.config.BATCH_SIZE, + shard_id=gpc.get_local_rank(ParallelMode.DATA), + num_shards=gpc.get_world_size(ParallelMode.DATA), + training=True, + gpu_aug=gpc.config.dali.gpu_aug, + cuda=True, + mixup_alpha=gpc.config.dali.mixup_alpha + ) + + +def build_dali_test(): + root = gpc.config.dali.root + val_pat = os.path.join(root, 'validation/*') + val_idx_pat = os.path.join(root, 'idx_files/validation/*') + return DaliDataloader( + sorted(glob.glob(val_pat)), + sorted(glob.glob(val_idx_pat)), + batch_size=gpc.config.BATCH_SIZE, + shard_id=gpc.get_local_rank(ParallelMode.DATA), + num_shards=gpc.get_world_size(ParallelMode.DATA), + training=False, + # gpu_aug=gpc.config.dali.gpu_aug, + gpu_aug=False, + cuda=True, + mixup_alpha=gpc.config.dali.mixup_alpha + ) + + +def main(): + engine, train_dataloader, test_dataloader = colossalai.initialize( + train_dataloader=build_dali_train, + test_dataloader=build_dali_test + ) + logger = get_global_dist_logger() + set_global_multitimer_status(True) + timer = colossalai.utils.get_global_multitimer() + trainer = Trainer(engine=engine, + verbose=True, + timer=timer) + + trainer.fit( + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + epochs=gpc.config.NUM_EPOCHS, + hooks_cfg=gpc.config.hooks, + display_progress=True, + test_interval=1 + ) + + +if __name__ == '__main__': + main() diff --git a/examples/vit-b16/vit-b16.py b/examples/vit-b16/vit-b16.py new file mode 100755 index 000000000000..ac51e226ef81 --- /dev/null +++ b/examples/vit-b16/vit-b16.py @@ -0,0 +1,78 @@ +from colossalai.engine import AMP_TYPE +from torch.nn import CrossEntropyLoss +from mixup import MixupLoss +from hooks import TotalBatchsizeHook +from colossalai.registry import MODELS +from timm.models import vit_base_patch16_224 + +MODELS.register_module(vit_base_patch16_224) + +LOG_NAME = 'vit-b16-1k-32k-mixup-light2' +# ViT Base +BATCH_SIZE = 256 +DROP_RATE = 0.1 +NUM_EPOCHS = 300 + +parallel = dict( + pipeline=dict(size=1), + tensor=dict(size=1, mode=None), +) + +optimizer = dict( + type='Lamb', + lr=1.8e-2, + weight_decay=0.1, +) + + +loss = dict( + type='MixupLoss', + loss_fn_cls=CrossEntropyLoss +) + +model = dict( + type='vit_base_patch16_224', + drop_rate=DROP_RATE, +) + +hooks = [ + dict(type='LogMetricByEpochHook'), + dict(type='AccuracyHook'), + dict(type='LossHook'), + dict(type='TotalBatchsizeHook'), + dict(type='TensorboardHook', log_dir=f'./tb_logs/{LOG_NAME}'), + dict(type='SaveCheckpointHook', interval=1, + checkpoint_dir=f'./ckpt/{LOG_NAME}'), + # dict(type='LoadCheckpointHook', epoch=10, + # checkpoint_dir=f'./ckpt/{LOG_NAME}'), + dict( + type='LRSchedulerHook', + by_epoch=True, + lr_scheduler_cfg=dict( + type='LinearWarmupLR', + warmup_steps=150 + ) + ), +] + +fp16 = dict( + mode=AMP_TYPE.TORCH, +) + + +logging = dict( + root_path=f"./logs/{LOG_NAME}" +) + +dali = dict( + root='./dataset/ILSVRC2012_1k', + gpu_aug=True, + mixup_alpha=0.2 +) + +engine = dict( + schedule=None, + gradient_handlers=None, + gradient_accumulation=16, + gradient_clipping=1.0, +) diff --git a/setup.py b/setup.py index 9949e9eeadc2..8541b0a6ce3a 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,6 @@ import os import subprocess import sys -import warnings import torch from setuptools import setup, find_packages @@ -23,13 +22,36 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_major, bare_metal_minor +def check_cuda_torch_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version( + cuda_dir) + torch_binary_major = torch.version.cuda.split(".")[0] + torch_binary_minor = torch.version.cuda.split(".")[1] + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor): + raise RuntimeError("Cuda extensions are being compiled with a version of Cuda that does " + + "not match the version used to compile Pytorch binaries. " + + "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk).") + + +def fetch_requirements(path): + with open(path, 'r') as fd: + return [r.strip() for r in fd.readlines()] + + if not torch.cuda.is_available(): # https://github.com/NVIDIA/apex/issues/486 # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). print('\nWarning: Torch did not find available GPUs on this system.\n', 'If your intention is to cross-compile, this is not an error.\n' - 'By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n' + 'By default, Colossal-AI will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n' 'Volta (compute capability 7.0), Turing (compute capability 7.5),\n' 'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n' 'If you wish to cross-compile for a single specific architecture,\n' @@ -46,66 +68,12 @@ def get_cuda_bare_metal_version(cuda_dir): TORCH_MINOR = int(torch.__version__.split('.')[1]) if TORCH_MAJOR == 0 and TORCH_MINOR < 4: - raise RuntimeError("Apex requires Pytorch 0.4 or newer.\n" + + raise RuntimeError("Colossal-AI requires Pytorch 0.4 or newer.\n" + "The latest stable release can be obtained from https://pytorch.org/") cmdclass = {} ext_modules = [] -extras = {} -if "--pyprof" in sys.argv: - string = "\n\nPyprof has been moved to its own dedicated repository and will " + \ - "soon be removed from Apex. Please visit\n" + \ - "https://github.com/NVIDIA/PyProf\n" + \ - "for the latest version." - warnings.warn(string, DeprecationWarning) - with open('requirements.txt') as f: - required_packages = f.read().splitlines() - extras['pyprof'] = required_packages - try: - sys.argv.remove("--pyprof") - except: - pass -else: - warnings.warn( - "Option --pyprof not specified. Not installing PyProf dependencies!") - -if "--cuda_ext" in sys.argv: - if TORCH_MAJOR == 0: - raise RuntimeError("--cuda_ext requires Pytorch 1.0 or later, " - "found torch.__version__ = {}".format(torch.__version__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output( - [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - bare_metal_minor = release[1][0] - - return raw_output, bare_metal_major, bare_metal_minor - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version( - cuda_dir) - torch_binary_major = torch.version.cuda.split(".")[0] - torch_binary_minor = torch.version.cuda.split(".")[1] - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor): - raise RuntimeError("Cuda extensions are being compiled with a version of Cuda that does " + - "not match the version used to compile Pytorch binaries. " + - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) + - "In some cases, a minor-version mismatch will not cause later errors: " + - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk).") - - # Set up macros for forward/backward compatibility hack around # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e # and @@ -123,6 +91,10 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir): version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5 if "--cuda_ext" in sys.argv: + if TORCH_MAJOR == 0: + raise RuntimeError("--cuda_ext requires Pytorch 1.0 or later, " + "found torch.__version__ = {}".format(torch.__version__)) + sys.argv.remove("--cuda_ext") if CUDA_HOME is None: @@ -145,17 +117,6 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir): # '--resource-usage', '--use_fast_math'] + version_dependent_macros})) -# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')): - generator_flag = ['-DOLD_GENERATOR'] - - -def fetch_requirements(path): - with open(path, 'r') as fd: - return [r.strip() for r in fd.readlines()] - install_requires = fetch_requirements('requirements/requirements.txt') @@ -170,6 +131,5 @@ def fetch_requirements(path): description='An integrated large-scale model training system with efficient parallelization techniques', ext_modules=ext_modules, cmdclass={'build_ext': BuildExtension} if ext_modules else {}, - extras_require=extras, install_requires=install_requires, )