1 change: 0 additions & 1 deletion colossalai/booster/__init__.py
@@ -1,4 +1,3 @@
from .accelerator import Accelerator
from .booster import Booster
from .environment_table import EnvironmentTable
from .plugin import Plugin
57 changes: 37 additions & 20 deletions colossalai/booster/booster.py
@@ -8,6 +8,8 @@
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import GeneralCheckpointIO

from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin
@@ -61,19 +63,21 @@ def __init__(self,
self.plugin = plugin

# set accelerator
if self.plugin and self.plugin.control_device:
if self.plugin and self.plugin.control_device():
self.accelerator = None
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
else:
self.accelerator = Accelerator(device)

# set precision
if mixed_precision is None or (self.plugin and self.plugin.control_precision):
self.mixed_precision = None
if self.plugin and self.plugin.control_precision():
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
else:
# validate and set precision
if isinstance(MixedPrecision, str):
if isinstance(mixed_precision, str):
# the user will take the default arguments for amp training
self.mixed_precision = mixed_precision_factory(mixed_precision)
elif isinstance(mixed_precision, MixedPrecision):
@@ -84,6 +88,11 @@
f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
)

if self.plugin is not None and self.plugin.control_checkpoint_io():
self.checkpoint_io = self.plugin.get_checkpoint_io()
else:
self.checkpoint_io = GeneralCheckpointIO()

def boost(
self,
model: nn.Module,
@@ -109,12 +118,13 @@ def boost(
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
model, optimizer, criterion, dataloader, lr_scheduler)

if self.plugin and not self.plugin.control_device:
if self.plugin and not self.plugin.control_device():
# transform model for accelerator
model = self.accelerator.configure(model)

if self.mixed_precision and self.plugin and not self.plugin.control_precision:
if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
# transform model for mixed precision
# when mixed_precision is specified and the plugin is not given or does not control the precision
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)

return model, optimizer, criterion, dataloader, lr_scheduler
@@ -140,18 +150,25 @@ def no_sync(self, model: nn.Module) -> contextmanager:
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model)

def save(self,
obj: Union[nn.Module, Optimizer, LRScheduler],
path_like: str,
plan: str = 'torch',
**kwargs) -> None:
# TODO: implement this method
pass
def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
self.checkpoint_io.load_model(model, checkpoint, strict)

def load(self,
obj: Union[nn.Module, Optimizer, LRScheduler],
path_like: str,
plan: str = 'torch',
**kwargs) -> None:
# TODO: implement this method
pass
def save_model(self,
model: nn.Module,
checkpoint: str,
prefix: str = None,
shard: bool = False,
size_per_shard: int = 1024):
self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
self.checkpoint_io.load_optimizer(optimizer, checkpoint)

def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)

def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
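With a `CheckpointIO` attached in `__init__`, checkpointing is now exposed directly on the `Booster` instead of the old placeholder `save`/`load` methods. A minimal usage sketch of the new API, assuming `Booster`'s default arguments (no plugin, so `GeneralCheckpointIO` is used); the toy model, optimizer, and paths below are placeholders, not part of this PR:

```python
# Hedged usage sketch of the checkpoint methods added to Booster in this PR.
import torch.nn as nn
from torch.optim import SGD

from colossalai.booster import Booster

model = nn.Linear(8, 2)                        # placeholder model
optimizer = SGD(model.parameters(), lr=1e-3)   # placeholder optimizer

booster = Booster(plugin=None)                 # no plugin -> GeneralCheckpointIO

# all checkpoint traffic is routed through booster.checkpoint_io
booster.save_model(model, 'model.pt')
booster.save_optimizer(optimizer, 'optimizer.pt')
booster.load_model(model, 'model.pt', strict=True)
booster.load_optimizer(optimizer, 'optimizer.pt')
```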
18 changes: 0 additions & 18 deletions colossalai/booster/environment_table.py

This file was deleted.

3 changes: 0 additions & 3 deletions colossalai/booster/interface/__init__.py

This file was deleted.

12 changes: 7 additions & 5 deletions colossalai/booster/mixed_precision/fp16_torch.py
@@ -5,7 +5,8 @@
from torch import Tensor
from torch.optim import Optimizer

from ..interface import OptimizerWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper

from .mixed_precision_base import MixedPrecision

__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
@@ -45,7 +46,9 @@ def backward(self, loss: Tensor, *args, **kwargs) -> None:
scaled_loss.backward(*args, **kwargs)

def step(self, *args, **kwargs) -> Optional[float]:
return self.scaler.step(self.optim, *args, **kwargs)
out = self.scaler.step(self.optim, *args, **kwargs)
self.scaler.update()
return out

def scale_loss(self, loss: Tensor) -> Tensor:
return self.scaler.scale(loss)
@@ -67,7 +70,7 @@ def clip_grad_by_norm(self,
super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)


class TorchAMPModule(nn.Module):
class TorchAMPModule(ModelWrapper):
"""
Module wrapper for mixed precision training in FP16 using PyTorch AMP.

@@ -76,8 +79,7 @@ class TorchAMPModule(nn.Module):
"""

def __init__(self, module: nn.Module):
super().__init__()
self.module = module
super().__init__(module)

def forward(self, *args, **kwargs):
with torch.cuda.amp.autocast():
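The extra `self.scaler.update()` in `step` above matters because PyTorch's `GradScaler` only adjusts the loss scale when `update()` is called after each `step()`; otherwise an overflow would keep skipping optimizer steps at a stale scale. For reference, this is the standard AMP pattern that the wrapper now reproduces internally (a generic sketch, not code from this PR):

```python
# Standard torch.cuda.amp training step: autocast forward, scaled backward,
# then scaler.step followed by scaler.update (the call this PR adds to the wrapper).
import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, optimizer, criterion, inputs, targets):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():       # forward pass runs in mixed precision
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()         # backward on the scaled loss
    scaler.step(optimizer)                # unscales grads, skips the step on inf/nan
    scaler.update()                       # adjusts the loss scale for the next step
    return loss.item()
```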
colossalai/booster/mixed_precision/mixed_precision_base.py
@@ -4,7 +4,7 @@
import torch.nn as nn
from torch.optim import Optimizer

from ..interface import OptimizerWrapper
from colossalai.interface import OptimizerWrapper


class MixedPrecision(ABC):
22 changes: 16 additions & 6 deletions colossalai/booster/plugin/plugin_base.py
@@ -6,34 +6,30 @@
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.booster.interface import OptimizerWrapper
from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper

__all__ = ['Plugin']


class Plugin(ABC):

@property
@abstractmethod
def supported_devices(self) -> List[str]:
pass

@property
@abstractmethod
def supported_precisions(self) -> List[str]:
pass

@property
@abstractmethod
def control_precision(self) -> bool:
pass

@property
@abstractmethod
def control_device(self) -> bool:
pass

@property
@abstractmethod
def support_no_sync(self) -> bool:
pass
@@ -49,3 +45,17 @@ def configure(
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
# implement this method
pass

@abstractmethod
def control_checkpoint_io(self) -> bool:
"""
Whether the plugin controls the checkpoint io
"""
pass

@abstractmethod
def get_checkpoint_io(self) -> CheckpointIO:
"""
Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
"""
pass
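With these two hooks, the `Booster` asks each plugin whether it wants to own checkpointing and, if so, which `CheckpointIO` to use. A bare-bones illustration of the contract (a hypothetical plugin; the remaining abstract members of `Plugin` are omitted for brevity, so this class is illustrative rather than instantiable as-is):

```python
# Hypothetical plugin sketch showing only the new checkpoint-related hooks.
from colossalai.booster.plugin.plugin_base import Plugin
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO


class MyCheckpointAwarePlugin(Plugin):
    # other abstract members (supported_devices, control_device, configure, ...)
    # would also need to be implemented in a real plugin

    def control_checkpoint_io(self) -> bool:
        # tell the Booster that this plugin supplies its own CheckpointIO
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        # only invoked when control_checkpoint_io() returns True
        return GeneralCheckpointIO()
```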
61 changes: 59 additions & 2 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -11,13 +11,61 @@
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.booster.interface import OptimizerWrapper
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper

from .plugin_base import Plugin

__all__ = ['TorchDDPPlugin']


class TorchDDPCheckpointIO(GeneralCheckpointIO):

def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()

def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint with automatic unwrapping.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)

def save_unsharded_model(self, model: nn.Module, checkpoint: str):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.save_model via ModelWrapper.unwrap
if self.coordinator.is_master():
super().save_unsharded_model(model, checkpoint)

def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""
Save optimizer to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_unsharded_optimizer(optimizer, checkpoint)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save lr scheduler to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)


class TorchDDPModel(ModelWrapper):

def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = DDP(module, *args, **kwargs)

def unwrap(self):
return self.module.module


class TorchDDPPlugin(Plugin):
"""
Plugin for PyTorch DDP.
@@ -138,10 +186,19 @@ def configure(
# cast model to cuda
model = model.cuda()

# convert model to sync bn
model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

# wrap the model with PyTorch DDP
model = DDP(model, **self.ddp_kwargs)
model = TorchDDPModel(model, **self.ddp_kwargs)

if not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)

return model, optimizer, criterion, dataloader, lr_scheduler

def control_checkpoint_io(self) -> bool:
return True

def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO()
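Putting the pieces together: because `TorchDDPPlugin.control_checkpoint_io()` returns `True`, the `Booster` picks up `TorchDDPCheckpointIO`, so `save_model` unwraps the `TorchDDPModel` and only the master rank writes files. A hedged end-to-end sketch (assumes the distributed process group is already initialized, e.g. via `colossalai.launch_from_torch`, and that the optional `boost` arguments can be omitted; the model, optimizer, and paths are placeholders):

```python
# Hedged sketch of TorchDDPPlugin together with the new checkpoint I/O path.
import torch.nn as nn
from torch.optim import SGD

from colossalai.booster import Booster
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPPlugin

model = nn.Linear(8, 2)                          # placeholder model
optimizer = SGD(model.parameters(), lr=1e-3)     # placeholder optimizer

plugin = TorchDDPPlugin()                        # control_checkpoint_io() -> True
booster = Booster(plugin=plugin)                 # Booster uses TorchDDPCheckpointIO

# configure() casts to CUDA, converts SyncBatchNorm, and wraps with TorchDDPModel
model, optimizer, *_ = booster.boost(model, optimizer)

# only the master rank writes; the DDP wrapper is unwrapped before saving
booster.save_model(model, 'ddp_model.pt')
booster.save_optimizer(optimizer, 'ddp_optimizer.pt')
```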