9 changes: 9 additions & 0 deletions colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py
@@ -0,0 +1,9 @@
from .base import MixedPrecisionMixin
from .bf16 import BF16MixedPrecisionMixin
from .fp16 import FP16MixedPrecisionMixin

__all__ = [
'MixedPrecisionMixin',
'FP16MixedPrecisionMixin',
'BF16MixedPrecisionMixin',
]
91 changes: 91 additions & 0 deletions colossalai/amp/naive_amp/mixed_precision_mixin/base.py
@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod

import torch
from torch import Tensor


class MixedPrecisionMixin(ABC):
"""A helper class for mixed precision training. This mixin is used in mixed precision optimizers.

Attributes:
dtype (torch.dtype): The expected dtype of the gradients.

Examples:
```python
class MyMixedPrecisionOptimizer(OptimizerWrapper):
def __init__(self, optim: Optimizer):
super().__init__(optim)
self.mixed_precision = MixedPrecisionMixin()

def backward(self, loss):
loss = self.mixed_precision.pre_backward(loss)
loss.backward()

def backward_by_grad(self, tensor, grad):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
tensor.backward(grad)

def step(self):
if self.mixed_precision.should_skip_step():
self.zero_grad()
return
div_scale = self.mixed_precision.get_grad_div_scale()
# maybe clip grad here
# maybe scale grad here
self.optim.step()

def zero_grad(self):
self.mixed_precision.pre_zero_grad()
return self.optim.zero_grad()
```
"""
dtype: torch.dtype

@abstractmethod
def pre_backward(self, loss: Tensor) -> Tensor:
"""Called before backward.

Args:
loss (Tensor): Loss value.

Returns:
Tensor: Loss value (possibly scaled).
"""
pass

@abstractmethod
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
"""Called before backward by grad. This is helpful for pipeline parallelism.

Args:
tensor (Tensor): Tensor to backward.
grad (Tensor): Gradient of the tensor.

Returns:
Tensor: Gradient of the tensor (possibly scaled).
"""
pass

@abstractmethod
def should_skip_step(self) -> bool:
"""Called before step.

Returns:
bool: Whether to skip the step.
"""
pass

@abstractmethod
def pre_zero_grad(self) -> None:
"""Called before zero_grad.
"""
pass

@abstractmethod
def get_grad_div_scale(self) -> float:
"""Called before step or clip_grad. To keep computation efficiency, this method does not (maybe) unscale grads.

Returns:
float: A divisor for gradient clipping or step.
"""
pass
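To make the `# maybe clip grad here` / `# maybe scale grad here` placeholders in the docstring above concrete, here is a hedged sketch of how a wrapper might consume `get_grad_div_scale()` for unscaling and clipping. The `OptimizerWrapper` import path, the `ClippingMixedPrecisionOptimizer` name, and the param-group traversal are illustrative assumptions, not APIs introduced by this PR.

```python
import torch
from torch.optim import Optimizer

from colossalai.amp.naive_amp.mixed_precision_mixin import MixedPrecisionMixin
from colossalai.interface import OptimizerWrapper  # assumed import path


class ClippingMixedPrecisionOptimizer(OptimizerWrapper):
    """Hypothetical wrapper showing where unscaling and clipping fit into step()."""

    def __init__(self, optim: Optimizer, mixin: MixedPrecisionMixin, max_norm: float = 1.0):
        super().__init__(optim)
        self.mixed_precision = mixin
        self.max_norm = max_norm

    def step(self):
        if self.mixed_precision.should_skip_step():
            self.zero_grad()
            return
        div_scale = self.mixed_precision.get_grad_div_scale()
        params = [p for group in self.optim.param_groups for p in group['params']]
        for p in params:
            if p.grad is not None:
                p.grad.div_(div_scale)  # unscale in place
        torch.nn.utils.clip_grad_norm_(params, self.max_norm)
        self.optim.step()

    def zero_grad(self):
        self.mixed_precision.pre_zero_grad()
        return self.optim.zero_grad()
```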
23 changes: 23 additions & 0 deletions colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py
@@ -0,0 +1,23 @@
import torch
from torch import Tensor

from .base import MixedPrecisionMixin


class BF16MixedPrecisionMixin(MixedPrecisionMixin):
dtype = torch.bfloat16

def pre_backward(self, loss: Tensor) -> Tensor:
return loss

def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
return grad

def should_skip_step(self) -> bool:
return False

def pre_zero_grad(self) -> None:
pass

def get_grad_div_scale(self) -> float:
return 1.0
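Since bf16 shares fp32's exponent range, no loss scaling or overflow tracking is needed, so every hook above is a no-op. A minimal sanity-check sketch (illustrative only, assumes the package is importable):

```python
import torch

from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin

mixin = BF16MixedPrecisionMixin()
loss = torch.tensor(2.5)

assert mixin.pre_backward(loss) is loss      # loss is returned unscaled
assert mixin.should_skip_step() is False     # steps are never skipped
assert mixin.get_grad_div_scale() == 1.0     # grads never need unscaling
```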
84 changes: 84 additions & 0 deletions colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
@@ -0,0 +1,84 @@
from abc import abstractmethod
from enum import Enum

import torch
import torch.distributed as dist
from torch import Tensor

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device

from .base import MixedPrecisionMixin


class OptimState(Enum):
SCALED = 0
UNSCALED = 1


class FP16MixedPrecisionMixin(MixedPrecisionMixin):
dtype = torch.float16

def __init__(self,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__()
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self.optim_state = OptimState.UNSCALED
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())

@property
def loss_scale(self) -> float:
return self.grad_scaler.scale.item()

@abstractmethod
def check_local_overflow(self) -> bool:
"""Check whether there is overflow in the local process. This method should be implemented by subclasses.

Returns:
bool: Whether there is overflow in the local process.
"""
pass

def check_overflow(self) -> bool:
# clear previous overflow record
self.found_overflow.fill_(0.0)
if self.check_local_overflow():
self.found_overflow.fill_(1.0)
dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX)
return self.found_overflow.item() > 0

def pre_backward(self, loss: Tensor) -> Tensor:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
return loss

def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
self.optim_state = OptimState.SCALED
return grad

def should_skip_step(self) -> bool:
found_inf = self.check_overflow()
self.grad_scaler.update(found_inf)
if found_inf:
self.optim_state = OptimState.UNSCALED
return found_inf

def pre_zero_grad(self) -> None:
pass

def get_grad_div_scale(self) -> float:
assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping'
self.optim_state = OptimState.UNSCALED
return self.loss_scale
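`check_local_overflow()` is the only abstract method a subclass must still supply. A hedged sketch of one possible implementation that scans a list of gradients for inf/nan; the class name and the parameter-list constructor are assumptions for illustration, not part of this PR.

```python
import torch

from colossalai.amp.naive_amp.mixed_precision_mixin import FP16MixedPrecisionMixin


class ParamGradFP16Mixin(FP16MixedPrecisionMixin):
    """Hypothetical concrete mixin that tracks the working parameters directly."""

    def __init__(self, params, **scaler_kwargs):
        super().__init__(**scaler_kwargs)
        self.params = list(params)

    def check_local_overflow(self) -> bool:
        # report overflow if any local gradient contains inf or nan
        return any(p.grad is not None and not torch.isfinite(p.grad).all()
                   for p in self.params)
```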
9 changes: 8 additions & 1 deletion colossalai/booster/plugin/gemini_plugin.py
@@ -23,6 +23,9 @@

__all__ = ['GeminiPlugin']

SUPPORTED_PRECISION = ['fp16', 'bf16']
PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}


class GeminiCheckpointIO(GeneralCheckpointIO):

@@ -171,6 +174,7 @@ class GeminiPlugin(DPPluginBase):
Args:
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
@@ -203,6 +207,7 @@ def __init__(
self,
device: Optional[torch.device] = None,
placement_policy: str = "cpu",
precision: str = "fp16",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
@@ -223,6 +228,7 @@
verbose: bool = False,
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
self.gemini_config = dict(
device=(device or get_current_device()),
placement_policy=placement_policy,
@@ -233,6 +239,7 @@
hidden_dim=hidden_dim,
min_chunk_size_mb=min_chunk_size_mb,
memstats=memstats,
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
)
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
self.optim_kwargs = dict(initial_scale=initial_scale,
@@ -253,7 +260,7 @@ def control_precision(self) -> bool:
return True

def supported_precisions(self) -> List[str]:
return ['fp16']
return SUPPORTED_PRECISION

def control_device(self) -> bool:
return True
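A hedged usage sketch of the new `precision` argument. It assumes the distributed environment is already initialized (e.g. via `colossalai.launch_from_torch`) and uses `HybridAdam` as a placeholder optimizer; both are assumptions about the surrounding API rather than changes made in this PR.

```python
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

plugin = GeminiPlugin(placement_policy='cpu', precision='bf16')  # 'fp16' is the default
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 16)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)
```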
33 changes: 22 additions & 11 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,4 +1,5 @@
import warnings
from functools import partial
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch
@@ -20,12 +21,15 @@
__all__ = ['LowLevelZeroPlugin']


def _convert_to_fp16(x):
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
return x.half()
return x.to(dtype)
return x


SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']


class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
@@ -49,17 +53,24 @@ class LowLevelZeroModel(ModelWrapper):

def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
super().__init__(module)
self.convert_inputs = (precision == 'fp16')
module = zero_model_wrapper(module, zero_stage=stage)
self.dtype = None
if precision == 'fp16':
module = module.half()
self.dtype = torch.float16
elif precision == 'bf16':
self.dtype = torch.bfloat16
module = zero_model_wrapper(module, zero_stage=stage)
if self.dtype is not None:
module = module.to(self.dtype)
module = module.to(get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)

def forward(self, *args, **kwargs):
if self.convert_inputs:
args = tree_map(_convert_to_fp16, args)
kwargs = tree_map(_convert_to_fp16, kwargs)
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)


@@ -110,7 +121,7 @@ class LowLevelZeroPlugin(DPPluginBase):

Args:
stage (int, optional): ZeRO stage. Defaults to 1.
precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'.
precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
@@ -149,7 +160,7 @@ def __init__(
) -> None:
super().__init__()
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training'
assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'

self.stage = stage
self.precision = precision
@@ -175,7 +186,7 @@ def control_precision(self) -> bool:
return True

def supported_precisions(self) -> List[str]:
return ['fp16', 'fp32']
return SUPPORTED_PRECISION

def control_device(self) -> bool:
return True
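A hedged sketch of what the new `convert_fn` does to `forward()` inputs: every floating-point tensor in `args`/`kwargs` is cast to the configured dtype, while integer tensors and non-tensor leaves pass through untouched. The `tree_map` import path below is PyTorch's pytree utility and is an assumption about how the plugin resolves it.

```python
from functools import partial

import torch
from torch.utils._pytree import tree_map


def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x


convert_fn = partial(_convert_floating_point, dtype=torch.bfloat16)
args = tree_map(convert_fn, (torch.randn(2, 2), torch.arange(4), 'label'))

assert args[0].dtype == torch.bfloat16   # float tensor is cast
assert args[1].dtype == torch.int64      # integer tensor is untouched
assert args[2] == 'label'                # non-tensor leaves pass through
```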
15 changes: 15 additions & 0 deletions colossalai/kernel/cuda_native/csrc/type_shim.h
@@ -171,6 +171,21 @@
using g_scalar_t_##LEVEL = at::Half; \
using p_scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Float && \
PTYPE == at::ScalarType::BFloat16) { \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::BFloat16 && \
PTYPE == at::ScalarType::Float) { \
using g_scalar_t_##LEVEL = at::BFloat16; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::BFloat16 && \
PTYPE == at::ScalarType::BFloat16) { \
using g_scalar_t_##LEVEL = at::BFloat16; \
using p_scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
} else { \
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
"'"); \
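For reference, a conceptual Python sketch (illustrative only) of the (gradient dtype, parameter dtype) branches this hunk adds; the macro maps each pair to concrete `g_scalar_t`/`p_scalar_t` typedefs before running the kernel body.

```python
import torch

# (grad dtype, param dtype) combinations newly dispatched by this hunk
PAIRS_ADDED = [
    (torch.float32, torch.bfloat16),   # fp32 grads, bf16 params
    (torch.bfloat16, torch.float32),   # bf16 grads, fp32 params
    (torch.bfloat16, torch.bfloat16),  # bf16 grads, bf16 params
]
```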