Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion colossalai/booster/plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .gemini_plugin import GeminiPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin

__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin']
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']
259 changes: 259 additions & 0 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import random
import warnings
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper

from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPCheckpointIO

__all__ = ['LowLevelZeroPlugin']


def _convert_to_fp16(x):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
return x.half()
return x


class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
# TODO(ver217): optimizer state dict is sharded
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)


class LowLevelZeroModel(ModelWrapper):

def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
super().__init__(module)
self.convert_inputs = (precision == 'fp16')
module = zero_model_wrapper(module, zero_stage=stage)
if precision == 'fp16':
module = module.half()
module = module.to(get_current_device())
self.module = module

def forward(self, *args, **kwargs):
if self.convert_inputs:
args = tree_map(_convert_to_fp16, args)
kwargs = tree_map(_convert_to_fp16, kwargs)
return super().forward(*args, **kwargs)


class LowLevelZeroOptimizer(OptimizerWrapper):

def __init__(self,
module: nn.Module,
optimizer: Optimizer,
zero_optim_config: dict,
optim_kwargs: dict,
verbose: bool = False) -> None:
optimizer = zero_optim_wrapper(module,
optimizer,
optim_config=zero_optim_config,
**optim_kwargs,
verbose=verbose)
super().__init__(optimizer)

def backward(self, loss: Tensor, *args, **kwargs):
self.optim.backward(loss)

def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm')

def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')


class LowLevelZeroPlugin(Plugin):
"""
Plugin for low level zero.

Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import LowLevelZeroPlugin
>>>
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = LowLevelZeroPlugin()

>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

Args:
strage (int, optional): ZeRO stage. Defaults to 1.
precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
reduce_bucket_size_in_m (int, optional): grad reduce bucket size in M. Defaults to 12.
communication_dtype (torch.dtype, optional): communication dtype. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): whether to overlap communication and computation. Defaults to True.
cpu_offload (bool, optional): whether to offload grad, master weight and optimizer state to cpu. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
"""

def __init__(
self,
stage: int = 1,
precision: str = 'fp16',
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
reduce_bucket_size_in_m: int = 12,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
cpu_offload: bool = False,
verbose: bool = False,
) -> None:

assert dist.is_initialized(
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training'

self.rank = dist.get_rank()
self.world_size = dist.get_world_size()

self.stage = stage
self.precision = precision
self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload)
self.optim_kwargs = dict(initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
self.verbose = verbose

def support_no_sync(self) -> bool:
return False

def control_precision(self) -> bool:
return True

def supported_precisions(self) -> List[str]:
return ['fp16', 'fp32']

def control_device(self) -> bool:
return True

def supported_devices(self) -> List[str]:
return ['cuda']

def prepare_train_dataloader(self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.

Note:
1. Evaluation datasets should not be passed to this function.

Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.

Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)

# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)

return DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)

def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:

if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.stage, self.precision)

if not isinstance(optimizer, OptimizerWrapper):
optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
self.verbose)

return model, optimizer, criterion, dataloader, lr_scheduler

def control_checkpoint_io(self) -> bool:
return True

def get_checkpoint_io(self) -> CheckpointIO:
return LowLevelZeroCheckpointIO()
1 change: 1 addition & 0 deletions colossalai/zero/low_level/low_level_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
# 2. contiguous gradients
# 3. cpu offload
# 4. support when some parameters requires_grad = False
# 5. support layer drop
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
Expand Down
3 changes: 2 additions & 1 deletion tests/kit/model_zoo/diffusers/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
data_unet_fn = lambda: dict(sample=torch.randn(2, 3, 32, 32), timestep=3)

identity_output = lambda x: x
clip_vision_model_output = lambda x: dict(pooler_output=x[1])


def data_clip_model():
Expand Down Expand Up @@ -65,7 +66,7 @@ def data_clip_vision():
model_zoo.register(name='diffusers_clip_vision_model',
model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()),
data_gen_fn=data_clip_vision,
output_transform_fn=identity_output)
output_transform_fn=clip_vision_model_output)

model_zoo.register(name='diffusers_unet2d_model',
model_fn=diffusers.UNet2DModel,
Expand Down
16 changes: 8 additions & 8 deletions tests/kit/model_zoo/torchaudio/torchaudio.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import torch
import torchaudio.models as tm

Expand Down Expand Up @@ -101,13 +103,11 @@ def tacotron_data_gen_fn():
mel_specgram_lengths=mel_specgram_lengths)


model_zoo.register(
name='torchaudio_tacotron',
model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
data_gen_fn=tacotron_data_gen_fn,
output_transform_fn=lambda outputs: dict(
spectrogram_before=outputs[0], spectrogram_after=outputs[1], stop_tokens=outputs[2], attn_weights=outputs[3]),
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='torchaudio_tacotron',
model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
data_gen_fn=tacotron_data_gen_fn,
output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)),
model_attribute=ModelAttribute(has_control_flow=True))


def wav2vec_data_gen_fn():
Expand All @@ -118,7 +118,7 @@ def wav2vec_data_gen_fn():


model_zoo.register(name='torchaudio_wav2vec2_base',
model_fn=tm.wav2vec2_base,
model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0),
data_gen_fn=wav2vec_data_gen_fn,
output_transform_fn=transformer_output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
Expand Down
8 changes: 4 additions & 4 deletions tests/kit/model_zoo/torchvision/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ def swin_s():


# special output transform fn
google_net_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.GoogLeNetOutputs
) else dict(output=x)
google_net_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs
) else dict(output=x)
swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val
for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
inception_v3_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.InceptionOutputs
) else dict(output=x)
inception_v3_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs
) else dict(output=x)

model_zoo.register(name='torchvision_alexnet',
model_fn=tm.alexnet,
Expand Down
Loading