9 changes: 9 additions & 0 deletions colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py
@@ -0,0 +1,9 @@
from .base import MixedPrecisionMixin
from .bf16 import BF16MixedPrecisionMixin
from .fp16 import FP16MixedPrecisionMixin

__all__ = [
    'MixedPrecisionMixin',
    'FP16MixedPrecisionMixin',
    'BF16MixedPrecisionMixin',
]
91 changes: 91 additions & 0 deletions colossalai/amp/naive_amp/mixed_precision_mixin/base.py
@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod

import torch
from torch import Tensor


class MixedPrecisionMixin(ABC):
    """A helper class for mixed precision training. This mixin is used in mixed precision optimizers.

    Attributes:
        dtype (torch.dtype): The expected dtype of the gradients.

    Examples:
        ```python
        class MyMixedPrecisionOptimizer(OptimizerWrapper):
            def __init__(self, optim: Optimizer):
                super().__init__(optim)
                # in practice this is a concrete subclass, e.g. FP16MixedPrecisionMixin
                self.mixed_precision = MixedPrecisionMixin()

            def backward(self, loss):
                loss = self.mixed_precision.pre_backward(loss)
                loss.backward()

            def backward_by_grad(self, tensor, grad):
                grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
                tensor.backward(grad)

            def step(self):
                if self.mixed_precision.should_skip_step():
                    self.zero_grad()
                    return
                div_scale = self.mixed_precision.get_grad_div_scale()
                # maybe clip grad here
                # maybe scale grad here
                self.optim.step()

            def zero_grad(self):
                self.mixed_precision.pre_zero_grad()
                return self.optim.zero_grad()
        ```
    """
    dtype: torch.dtype

    @abstractmethod
    def pre_backward(self, loss: Tensor) -> Tensor:
        """Called before backward.

        Args:
            loss (Tensor): Loss value.

        Returns:
            Tensor: Loss value (possibly scaled).
        """
        pass

    @abstractmethod
    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        """Called before ``backward_by_grad``, which is useful for pipeline parallelism where gradients are passed between stages.

        Args:
            tensor (Tensor): Tensor to backward.
            grad (Tensor): Gradient of the tensor.

        Returns:
            Tensor: Gradient of the tensor (possibly scaled).
        """
        pass

    @abstractmethod
    def should_skip_step(self) -> bool:
        """Called before step.

        Returns:
            bool: Whether to skip the step.
        """
        pass

    @abstractmethod
    def pre_zero_grad(self) -> None:
        """Called before zero_grad."""
        pass

    @abstractmethod
    def get_grad_div_scale(self) -> float:
        """Called before step or clip_grad. For efficiency, this method may not unscale the gradients itself; instead it returns a divisor that the caller applies when clipping or stepping.

        Returns:
            float: A divisor for gradient clipping or step.
        """
        pass
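To make the `get_grad_div_scale` contract concrete, here is a minimal sketch (the helper name `clip_and_unscale_` and its argument list are illustrative, not part of this PR): fp16 gradients still carry the loss scale at step time, so their norm is the true norm times the scale, and a single division by the returned divisor both unscales and clips.

```python
import torch

from colossalai.amp.naive_amp.mixed_precision_mixin import MixedPrecisionMixin


def clip_and_unscale_(grads: list, mixin: MixedPrecisionMixin, max_norm: float) -> None:
    # at this point fp16 grads are still multiplied by the loss scale
    div_scale = mixin.get_grad_div_scale()    # loss scale for fp16, 1.0 for bf16
    # norm of the *scaled* grads; dividing by div_scale recovers the true norm
    total_norm = torch.norm(torch.stack([g.norm() for g in grads]))
    clip = (total_norm / div_scale) / max_norm
    if clip > 1:
        div_scale = clip * div_scale          # fold clipping into the same divisor
    for g in grads:
        g.mul_(1.0 / div_scale)               # a single multiply unscales and clips
```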
23 changes: 23 additions & 0 deletions colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py
@@ -0,0 +1,23 @@
import torch
from torch import Tensor

from .base import MixedPrecisionMixin


class BF16MixedPrecisionMixin(MixedPrecisionMixin):
    dtype = torch.bfloat16

    def pre_backward(self, loss: Tensor) -> Tensor:
        return loss

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        return grad

    def should_skip_step(self) -> bool:
        return False

    def pre_zero_grad(self) -> None:
        pass

    def get_grad_div_scale(self) -> float:
        return 1.0
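Note: `BF16MixedPrecisionMixin` can be a pure pass-through because bfloat16 keeps float32's 8-bit exponent range, so gradients are unlikely to overflow and neither loss scaling nor overflow-driven step skipping is required.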
84 changes: 84 additions & 0 deletions colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
@@ -0,0 +1,84 @@
from abc import abstractmethod
from enum import Enum

import torch
import torch.distributed as dist
from torch import Tensor

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device

from .base import MixedPrecisionMixin


class OptimState(Enum):
    SCALED = 0
    UNSCALED = 1


class FP16MixedPrecisionMixin(MixedPrecisionMixin):
    dtype = torch.float16

    def __init__(self,
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32) -> None:
        super().__init__()
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale)
        self.optim_state = OptimState.UNSCALED
        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())

    @property
    def loss_scale(self) -> float:
        return self.grad_scaler.scale.item()

    @abstractmethod
    def check_local_overflow(self) -> bool:
        """Check whether there is overflow in the local process. This method should be implemented by subclasses.

        Returns:
            bool: Whether there is overflow in the local process.
        """
        pass

    def check_overflow(self) -> bool:
        # clear previous overflow record
        self.found_overflow.fill_(0.0)
        if self.check_local_overflow():
            self.found_overflow.fill_(1.0)
        dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX)
        return self.found_overflow.item() > 0

    def pre_backward(self, loss: Tensor) -> Tensor:
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        return loss

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        self.optim_state = OptimState.SCALED
        return grad

    def should_skip_step(self) -> bool:
        found_inf = self.check_overflow()
        self.grad_scaler.update(found_inf)
        if found_inf:
            self.optim_state = OptimState.UNSCALED
        return found_inf

    def pre_zero_grad(self) -> None:
        pass

    def get_grad_div_scale(self) -> float:
        assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping'
        self.optim_state = OptimState.UNSCALED
        return self.loss_scale
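`FP16MixedPrecisionMixin` leaves only `check_local_overflow` abstract. A hedged sketch of what a concrete subclass might look like (the class name and the `working_params` attribute are hypothetical, not from this PR; the ZeRO variant further down scans a `GradientStore` instead):

```python
import torch

from colossalai.amp.naive_amp.mixed_precision_mixin import FP16MixedPrecisionMixin


class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

    def __init__(self, working_params: list, **kwargs) -> None:
        super().__init__(**kwargs)
        self.working_params = working_params    # fp16 params whose grads we scan

    def check_local_overflow(self) -> bool:
        # any non-finite value (inf/nan) in a local grad counts as overflow
        for p in self.working_params:
            if p.grad is not None and not torch.isfinite(p.grad).all():
                return True
        return False
```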
106 changes: 56 additions & 50 deletions colossalai/zero/low_level/low_level_optim.py
@@ -6,7 +6,11 @@
import torch.distributed as dist
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.amp.naive_amp.mixed_precision_mixin import (
    BF16MixedPrecisionMixin,
    FP16MixedPrecisionMixin,
    MixedPrecisionMixin,
)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
@@ -27,6 +31,31 @@
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket


class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

    def __init__(self,
                 num_working_param_groups: int,
                 grad_store: GradientStore,
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32) -> None:
        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
                         max_scale)
        self.num_working_param_groups = num_working_param_groups
        self.grad_store = grad_store

    def check_local_overflow(self) -> bool:
        for group_id in range(self.num_working_param_groups):
            for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
                if avg_grad is not None and has_inf_or_nan(avg_grad):
                    return True
        return False

class LowLevelZeroOptimizer(ColossalaiOptimizer):
    """Optimizer used for ZeRO-1 and ZeRO-2.
    """
@@ -100,17 +129,6 @@ def __init__(
        self._reduce_bucket_size = reduce_bucket_size
        self._communication_dtype = communication_dtype

        # gradient scaler
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale,
                                             verbose=verbose)
        self._found_overflow = torch.FloatTensor([0]).to(get_current_device())

        # gradient clipping
        self._clip_grad_norm = clip_grad_norm

@@ -200,14 +218,25 @@ def __init__(
        if self._overlap_communication or self._partition_grads:
            self._attach_reduction_hook()

        # initialize mixed precision mixin
        self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
        if self._dtype is torch.float16:
            self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups,
                                                                             self._grad_store,
                                                                             initial_scale=initial_scale,
                                                                             min_scale=min_scale,
                                                                             growth_factor=growth_factor,
                                                                             backoff_factor=backoff_factor,
                                                                             growth_interval=growth_interval,
                                                                             hysteresis=hysteresis,
                                                                             max_scale=max_scale)
        elif self._dtype is torch.bfloat16:
            self.mixed_precision_mixin = BF16MixedPrecisionMixin()

    @property
    def dtype(self):
        return self._dtype

    @property
    def loss_scale(self):
        return self.grad_scaler.scale

    @property
    def num_param_groups(self):
        return len(self._working_param_groups)
@@ -392,7 +421,8 @@ def _add_to_reduction_bucket(self, param, reduce_rank=None):
    ################################

    def backward(self, loss, retain_graph=False, sync_grad=True):
        loss = self.loss_scale * loss
        if self.mixed_precision_mixin is not None:
            loss = self.mixed_precision_mixin.pre_backward(loss)
        loss.backward(retain_graph=retain_graph)

        # finish gradient reduction
@@ -419,6 +449,8 @@ def zero_grad(self, set_to_none=True):
        :param set_to_none: Whether to set the gradient to None. Default value is True.
        :type set_to_none: bool
        """
        if self.mixed_precision_mixin is not None:
            self.mixed_precision_mixin.pre_zero_grad()
        for _, param_group in self._working_param_groups.items():
            for param in param_group:
                if set_to_none:
@@ -435,12 +467,7 @@ def step(self, closure=None):
    def step(self, closure=None):
        assert closure is None, 'closure is not supported by step()'

        # check for overflow
        found_inf = self._check_overflow()
        self.grad_scaler.update(found_inf)

        # update loss scale if overflow occurs
        if found_inf:
        if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
            self._grad_store.reset_all_average_gradients()
            if self._verbose:
                self._logger.info(f'Found overflow. Skip step')
@@ -507,41 +534,20 @@
    # Mixed Precision Utilities #
    #############################

    def _check_overflow(self):
        # clear previous overflow record
        self._found_overflow.fill_(0.0)

        # check for overflow
        for group_id in range(len(self._working_param_groups)):
            for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
                if avg_grad is not None and has_inf_or_nan(avg_grad):
                    self._found_overflow.fill_(1.0)
                    break

        # all-reduce across dp group
        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)

        # all-reduce over model parallel group
        if self._mp_torch_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)

        if self._found_overflow.item() > 0:
            return True
        else:
            return False

    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
        # compute combined scale factor for this group
        combined_scale = self.loss_scale
        div_scale = 1.0
        if self.mixed_precision_mixin is not None:
            div_scale = self.mixed_precision_mixin.get_grad_div_scale()

        if self._clip_grad_norm > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
            clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
            if clip > 1:
                combined_scale = clip * self.loss_scale
                div_scale = clip * div_scale

        for grad in grad_groups_flat:
            grad.data.mul_(1. / combined_scale)
            grad.data.mul_(1. / div_scale)
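As a worked example of the rewritten divisor logic (illustrative numbers, not from the PR): with a loss scale of 65536, `clip_grad_norm` of 1.0, and a scaled gradient norm of 131072 (true norm 2.0), `clip` evaluates to roughly 2.0, so `div_scale` becomes 131072 and the single `mul_(1. / div_scale)` unscales by 65536 and clips by a factor of 2 in one pass.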

    ############################
    # Gradient Synchronization #