90 changes: 55 additions & 35 deletions colossalai/booster/booster.py
@@ -1,9 +1,9 @@
import warnings
from contextlib import contextmanager
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
@@ -55,50 +55,69 @@ def __init__(self,
device: str = 'cuda',
mixed_precision: Union[MixedPrecision, str] = None,
plugin: Optional[Plugin] = None) -> None:
# TODO(FrankLeeeee): add plugin control logic
# if self.plugin is not None and self.plugin.control_accelerator:
# ...
# create acclerator
self.acceleartor = Accelerator(device)
self.acceleartor.set_default_device()

# validate and set precision
if isinstance(MixedPrecision, str):
# the user will take the default arguments for amp training
self.mixed_precision = mixed_precision_factory(mixed_precision)
elif isinstance(mixed_precision, MixedPrecision):
# the user can customize the arguments by passing the precision object
self.mixed_precision = mixed_precision
if plugin is not None:
assert isinstance(
plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.'
self.plugin = plugin

# set accelerator
if self.plugin and self.plugin.control_device:
self.accelerator = None
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
else:
raise ValueError(
f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
)
self.accelerator = Accelerator(device)

def boost(self, model: nn.Module, optimizer: Optimizer, criterion: Callable, lr_scheduler: LRScheduler,
dataloader: DataLoader) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
# set precision
if mixed_precision is None or (self.plugin and self.plugin.control_precision):
self.mixed_precision = None
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
else:
# validate and set precision
if isinstance(MixedPrecision, str):
# the user will take the default arguments for amp training
self.mixed_precision = mixed_precision_factory(mixed_precision)
elif isinstance(mixed_precision, MixedPrecision):
# the user can customize the arguments by passing the precision object
self.mixed_precision = mixed_precision
else:
raise ValueError(
f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
)

def boost(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
"""
Boost the model, optimizer, criterion, lr_scheduler, and dataloader.

Args:
model (nn.Module): The model to be boosted.
optimizer (Optimizer): The optimizer to be boosted.
criterion (Callable): The criterion to be boosted.
lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
dataloader (DataLoader): The dataloader to be boosted.
lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
"""
# TODO(FrankLeeeee): add plugin control logic
# if self.plugin is not None and self.plugin.control_accelerator:
# ...
model = self.acceleartor.configure_model(model)

# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
# TODO(lsg): Add plugin control logic
# e.g.
# if self.plugin is not None and self.plugin.control_boost:
# ...
# TODO(FrankLeeeee): consider multi-dataloader case
# transform model for mixed precision
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
return model, optimizer, criterion, lr_scheduler, dataloader
if self.plugin:
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
model, optimizer, criterion, dataloader, lr_scheduler)

if self.plugin and not self.plugin.control_device:
# transform model for accelerator
model = self.accelerator.configure(model)

if self.mixed_precision and self.plugin and not self.plugin.control_precision:
# transform model for mixed precision
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)

return model, optimizer, criterion, dataloader, lr_scheduler
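For orientation, here is a minimal sketch of how the reworked boost flow above might be driven end to end; it assumes a TorchDDPPlugin together with pre-built model, optimizer, criterion, lr_scheduler and train_dataset objects, none of which are part of this diff.

# Illustration only: the objects passed in below are assumed to exist already.
plugin = TorchDDPPlugin()
train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8, shuffle=True)
booster = Booster(plugin=plugin)
model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(
    model, optimizer, criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler)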

def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
# TODO: implement this method with plugin
@@ -117,8 +136,9 @@ def execute_pipeline(self,
pass

def no_sync(self, model: nn.Module) -> contextmanager:
# TODO: implement this method
pass
assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model)
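A hedged illustration of the delegated no_sync above in a gradient-accumulation loop. The loop variables are assumed to exist, and loss.backward() stands in for Booster.backward, which is still a stub in this diff.

# Illustration only: `booster`, `model`, `criterion`, `optimizer`,
# `train_dataloader` and `accumulation_steps` are assumed to exist.
for step, (inputs, labels) in enumerate(train_dataloader):
    if (step + 1) % accumulation_steps != 0:
        # intermediate micro-batch: run forward and backward without the DDP all-reduce
        with booster.no_sync(model):
            loss = criterion(model(inputs.cuda()), labels.cuda())
            loss.backward()
    else:
        # final micro-batch: gradients are synchronized as usual
        loss = criterion(model(inputs.cuda()), labels.cuda())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()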

def save(self,
obj: Union[nn.Module, Optimizer, LRScheduler],
46 changes: 0 additions & 46 deletions colossalai/booster/plugin.py

This file was deleted.

4 changes: 4 additions & 0 deletions colossalai/booster/plugin/__init__.py
@@ -0,0 +1,4 @@
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin

__all__ = ['Plugin', 'TorchDDPPlugin']
51 changes: 51 additions & 0 deletions colossalai/booster/plugin/plugin_base.py
@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple, Union

import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.booster.interface import OptimizerWrapper

__all__ = ['Plugin']


class Plugin(ABC):

@property
@abstractmethod
def supported_devices(self) -> List[str]:
pass

@property
@abstractmethod
def supported_precisions(self) -> List[str]:
pass

@property
@abstractmethod
def control_precision(self) -> bool:
pass

@property
@abstractmethod
def control_device(self) -> bool:
pass

@property
@abstractmethod
def support_no_sync(self) -> bool:
pass

@abstractmethod
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
# implement this method
pass
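To make the contract above concrete, a hedged sketch of a pass-through subclass follows; the class name NoOpPlugin and its behaviour are illustrative only and not part of this PR. It reuses the imports from plugin_base.py.

# Illustration only, not part of this PR: a minimal pass-through subclass of the
# Plugin ABC above. The name NoOpPlugin is hypothetical.
class NoOpPlugin(Plugin):

    @property
    def supported_devices(self) -> List[str]:
        return ['cpu', 'cuda']

    @property
    def supported_precisions(self) -> List[str]:
        return ['fp16', 'bf16']

    @property
    def control_precision(self) -> bool:
        return False

    @property
    def control_device(self) -> bool:
        return False

    @property
    def support_no_sync(self) -> bool:
        return False

    def configure(self,
                  model: nn.Module,
                  optimizer: Optimizer,
                  criterion: Callable = None,
                  dataloader: DataLoader = None,
                  lr_scheduler: LRScheduler = None):
        # hand everything back unchanged; a real plugin would wrap the model here
        return model, optimizer, criterion, dataloader, lr_scheduler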
147 changes: 147 additions & 0 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -0,0 +1,147 @@
import random
from typing import Callable, List, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.booster.interface import OptimizerWrapper

from .plugin_base import Plugin

__all__ = ['TorchDDPPlugin']


class TorchDDPPlugin(Plugin):
"""
Plugin for PyTorch DDP.

Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import TorchDDPPlugin
>>>
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = TorchDDPPlugin()

>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

Args:
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True.
bucket_cap_mb (int, optional): The bucket size in MB. Defaults to 25.
find_unused_parameters (bool, optional): Whether to find unused parameters. Defaults to False.
check_reduction (bool, optional): Whether to check reduction. Defaults to False.
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Defaults to False.
static_graph (bool, optional): Whether to use static graph. Defaults to False.
"""

def __init__(self,
broadcast_buffers: bool = True,
bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False) -> None:

assert dist.is_initialized(
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph)

def support_no_sync(self) -> bool:
return True

def control_precision(self) -> bool:
return False

def supported_precisions(self) -> List[str]:
return ['fp16', 'fp16_apex', 'bf16', 'fp8']

def control_device(self) -> bool:
return True

def supported_devices(self) -> List[str]:
return ['cuda']

def prepare_train_dataloader(self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.

Note:
1. Evaluation datasets should not be passed to this function.

Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.

Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)

# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)

return DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
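A brief hedged usage sketch of prepare_train_dataloader above; train_dataset is assumed to be an existing torch Dataset, and torch.distributed is assumed to be initialized.

# Illustration only: `train_dataset` is assumed to exist.
plugin = TorchDDPPlugin()
train_dataloader = plugin.prepare_train_dataloader(train_dataset,
                                                   batch_size=32,
                                                   shuffle=True,
                                                   drop_last=True,
                                                   pin_memory=True,
                                                   num_workers=4)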

def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
# cast model to cuda
model = model.cuda()

# wrap the model with PyTorch DDP
model = DDP(model, **self.ddp_kwargs)

if not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)

return model, optimizer, criterion, dataloader, lr_scheduler