Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build_on_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
TESTMON_CORE_PKGS: /__w/ColossalAI/ColossalAI/requirements/requirements.txt,/__w/ColossalAI/ColossalAI/requirements/requirements-test.txt

- name: Store Testmon Cache
run: |
Expand Down
149 changes: 149 additions & 0 deletions colossalai/amp/naive_amp/mixed_precision_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from typing import Dict, List

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper

from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin


class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

def __init__(self,
working_params: List[Parameter],
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
self.params = working_params

def check_local_overflow(self) -> bool:
for p in self.params:
if p.grad is not None and not torch.isfinite(p.grad).all():
return True
return False


class MixedPrecisionOptimizer(OptimizerWrapper):

def __init__(self,
optim: Optimizer,
precision: str = 'fp16',
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0):
super().__init__(optim)
if precision == 'fp16':
working_params = []
for group in self.optim.param_groups:
for p in group['params']:
working_params.append(p)
self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
elif precision == 'bf16':
self.mixed_precision = BF16MixedPrecisionMixin()
else:
raise ValueError(f'Unsupported precision: {precision}')
if max_norm > 0.0:
raise NotImplementedError('max_norm is not supported yet.')
self.max_norm = max_norm
self.working_to_master_map: Dict[Parameter, Tensor] = {}
self.master_to_working_map: Dict[Tensor, Parameter] = {}

# create master weights
for group in self.optim.param_groups:
master_params = []
for p in group['params']:
if p.requires_grad:
master_p = p
if p.dtype != torch.float:
master_p = p.detach().float()
self.working_to_master_map[p] = master_p
self.master_to_working_map[master_p] = p
master_params.append(master_p)
group['params'] = master_params

def backward(self, loss: Tensor, *args, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
loss.backward(*args, **kwargs)

def backward_by_grad(self, tensor: Tensor, grad: Tensor):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
tensor.backward(grad)

def zero_grad(self, *args, **kwargs):
for p in self.working_to_master_map.keys():
p.grad = None
self.mixed_precision.pre_zero_grad()
return super().zero_grad(*args, **kwargs)

def _unscale_and_clip_grads(self, total_norm: float) -> None:
div_scale = 1.0
if self.mixed_precision is not None:
div_scale = self.mixed_precision.get_grad_div_scale()

if self.max_norm > 0.:
# norm is in fact norm*scale
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
div_scale = clip * div_scale

for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
p.grad.data.mul_(1. / div_scale)

def _compute_grad_norm(self) -> float:
if self.max_norm <= 0.:
return 0.
grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
if len(grads) == 0:
return 0.
device = grads[0].device
# TODO(ver217): support tp
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
return total_norm.item()

def step(self, *args, **kwargs):
if self.mixed_precision.should_skip_step():
self.zero_grad()
return
# prepare grads
for group in self.optim.param_groups:
for p in group['params']:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
if working_param.grad is None:
p.grad = working_param.grad.data.float()
working_param.grad = None
total_norm = self._compute_grad_norm()
self._unscale_and_clip_grads(total_norm)
self.optim.step(*args, **kwargs)
# update working params
for group in self.optim.param_groups:
for p in group['params']:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
working_param.data.copy_(p.data)
22 changes: 13 additions & 9 deletions colossalai/booster/booster.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from contextlib import contextmanager
from typing import Callable, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Iterator, List, Optional, Union

import torch
import torch.nn as nn
Expand All @@ -9,11 +9,12 @@
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper

from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin
from .plugin.pp_plugin_base import PipelinePluginBase

__all__ = ['Booster']

Expand Down Expand Up @@ -144,26 +145,29 @@ def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
def execute_pipeline(self,
data_iter: Iterator,
model: nn.Module,
criterion: Callable[[torch.Tensor], torch.Tensor],
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optimizer,
return_loss: bool = True,
return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
# TODO: implement this method
return_outputs: bool = False) -> dict:
# run pipeline forward backward pass
# return loss or outputs if needed
pass
assert isinstance(self.plugin,
PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)

def no_sync(self, model: nn.Module) -> contextmanager:
def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
"""Context manager to disable gradient synchronization across DP process groups.
Support torch DDP and Low Level ZeRO-1 for now.

Args:
model (nn.Module): The model to be disabled gradient synchronization.
model (nn.Module): The model to be disabled gradient synchronization, for DDP
optimizer (OptimizerWrapper): The optimizer to be disabled gradient synchronization, for ZeRO1-1

Returns:
contextmanager: Context to disable gradient synchronization.
"""
assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model)

def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
Expand Down
3 changes: 2 additions & 1 deletion colossalai/booster/plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .gemini_plugin import GeminiPlugin
from .hybrid_parallel_plugin import HybridParallelPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin

__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin']

import torch
from packaging import version
Expand Down
Loading