class MixedPrecisionMixin(ABC):
    """Hook interface that a mixed precision optimizer calls around backward, step and zero_grad.

    The mixin itself holds no optimizer; a wrapper optimizer owns an instance of a concrete
    subclass and invokes the hooks below at the matching points of its own methods.

    Attributes:
        dtype (torch.dtype): The expected dtype of the gradients.

    Examples:
        ```python
        class MyMixedPrecisionOptimizer(OptimizerWrapper):
            def __init__(self, optim: Optimizer):
                super().__init__(optim)
                # MixedPrecisionMixin is abstract: instantiate a concrete subclass here
                self.mixed_precision = ...

            def backward(self, loss):
                loss = self.mixed_precision.pre_backward(loss)
                loss.backward()

            def backward_by_grad(self, tensor, grad):
                grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
                tensor.backward(grad)

            def step(self):
                if self.mixed_precision.should_skip_step():
                    self.zero_grad()
                    return
                div_scale = self.mixed_precision.get_grad_div_scale()
                # maybe clip grad here
                # maybe scale grad here
                self.optim.step()

            def zero_grad(self):
                self.mixed_precision.pre_zero_grad()
                return self.optim.zero_grad()
        ```
    """
    dtype: torch.dtype

    @abstractmethod
    def pre_backward(self, loss: Tensor) -> Tensor:
        """Hook invoked right before ``loss.backward()``.

        Args:
            loss (Tensor): Loss value.

        Returns:
            Tensor: The loss to run backward on (possibly scaled).
        """

    @abstractmethod
    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        """Hook invoked right before ``tensor.backward(grad)``; useful for pipeline parallelism.

        Args:
            tensor (Tensor): Tensor to backward.
            grad (Tensor): Gradient of the tensor.

        Returns:
            Tensor: The gradient to run backward with (possibly scaled).
        """

    @abstractmethod
    def should_skip_step(self) -> bool:
        """Hook invoked at the beginning of the optimizer step.

        Returns:
            bool: ``True`` if the caller should skip this step.
        """

    @abstractmethod
    def pre_zero_grad(self) -> None:
        """Hook invoked right before the gradients are zeroed."""

    @abstractmethod
    def get_grad_div_scale(self) -> float:
        """Hook invoked before the step or before gradient clipping.

        For efficiency the implementation is not required to unscale the gradients itself;
        it only reports the divisor and the caller applies it.

        Returns:
            float: A divisor for gradient clipping or step.
        """


class BF16MixedPrecisionMixin(MixedPrecisionMixin):
    """bfloat16 policy: every hook is a pass-through, so no scaling and no skipped steps."""
    dtype = torch.bfloat16

    def pre_backward(self, loss: Tensor) -> Tensor:
        """Return ``loss`` untouched."""
        return loss

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        """Return ``grad`` untouched; ``tensor`` is not used."""
        return grad

    def should_skip_step(self) -> bool:
        """Never request a skipped step."""
        return False

    def pre_zero_grad(self) -> None:
        """Nothing to do before zeroing gradients."""
        return None

    def get_grad_div_scale(self) -> float:
        """Return ``1.0``, i.e. gradients need no division."""
        return 1.0
class OptimState(Enum):
    """Tracks whether the gradients are currently marked as carrying the loss scale."""
    SCALED = 0
    UNSCALED = 1


class FP16MixedPrecisionMixin(MixedPrecisionMixin):
    """float16 policy built on a ``DynamicGradScaler``.

    The loss is multiplied by the current loss scale before backward, steps are skipped when
    an overflow is detected on any rank, and the loss scale is reported as the divisor to
    apply to the gradients. Subclasses supply the local overflow check.

    Args:
        initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
        max_scale: Forwarded unchanged to ``DynamicGradScaler``.
    """
    dtype = torch.float16

    def __init__(self,
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32) -> None:
        super().__init__()
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale)
        self.optim_state = OptimState.UNSCALED
        # one-element flag tensor reused by check_overflow() for the cross-rank reduction
        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())

    @property
    def loss_scale(self) -> float:
        """Current loss scale, read from the grad scaler as a Python number."""
        return self.grad_scaler.scale.item()

    @abstractmethod
    def check_local_overflow(self) -> bool:
        """Check whether there is overflow in the local process. This method should be implemented by subclasses.

        Returns:
            bool: Whether there is overflow in the local process.
        """

    def check_overflow(self) -> bool:
        """Return whether any rank reports an overflow.

        The local result is written into ``self.found_overflow`` (overwriting the previous
        record) and MAX-reduced over the default process group. When ``torch.distributed``
        is unavailable or no process group has been initialised, the reduction is skipped
        and the local result is returned, so the mixin also works in a single process.
        """
        self.found_overflow.fill_(1.0 if self.check_local_overflow() else 0.0)
        # is_available() must be tested first: is_initialized() may not exist otherwise
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX)
        return self.found_overflow.item() > 0

    def pre_backward(self, loss: Tensor) -> Tensor:
        """Multiply ``loss`` by the current loss scale and mark the gradients as scaled."""
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        return loss

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        """Mark the gradients as scaled and return ``grad`` unchanged.

        NOTE(review): presumably ``grad`` already carries the loss scale from the stage that
        produced it -- confirm against the pipeline caller.
        """
        self.optim_state = OptimState.SCALED
        return grad

    def should_skip_step(self) -> bool:
        """Run the overflow check, feed the result to the grad scaler and report it.

        Returns:
            bool: ``True`` if an overflow was found, in which case the state is reset to
            ``UNSCALED`` because the caller is expected to drop this step.
        """
        found_inf = self.check_overflow()
        self.grad_scaler.update(found_inf)
        if found_inf:
            self.optim_state = OptimState.UNSCALED
        return found_inf

    def pre_zero_grad(self) -> None:
        """Nothing to do before zeroing gradients."""
        pass

    def get_grad_div_scale(self) -> float:
        """Return the loss scale as the divisor for the gradients and mark them as unscaled.

        Raises:
            AssertionError: If the gradients are not currently marked as scaled.
        """
        assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping'
        self.optim_state = OptimState.UNSCALED
        return self.loss_scale


class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
    """FP16 mixin whose local overflow check scans the averaged gradients of a ``GradientStore``.

    Args:
        num_working_param_groups (int): Number of parameter groups to scan.
        grad_store (GradientStore): Store queried via ``get_averaged_gradients_by_group``.
        initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
        max_scale: Forwarded to ``FP16MixedPrecisionMixin``.
    """

    def __init__(self,
                 num_working_param_groups: int,
                 grad_store: GradientStore,
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32) -> None:
        # forward by keyword so a reordering of the parent signature cannot silently mix up values
        super().__init__(initial_scale=initial_scale,
                         min_scale=min_scale,
                         growth_factor=growth_factor,
                         backoff_factor=backoff_factor,
                         growth_interval=growth_interval,
                         hysteresis=hysteresis,
                         max_scale=max_scale)
        self.num_working_param_groups = num_working_param_groups
        self.grad_store = grad_store

    def check_local_overflow(self) -> bool:
        """Return ``True`` as soon as any non-``None`` averaged gradient contains inf or NaN."""
        for group_id in range(self.num_working_param_groups):
            for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
                if avg_grad is not None and has_inf_or_nan(avg_grad):
                    return True
        return False
self.grad_scaler.scale - @property def num_param_groups(self): return len(self._working_param_groups) @@ -392,7 +421,8 @@ def _add_to_reduction_bucket(self, param, reduce_rank=None): ################################ def backward(self, loss, retain_graph=False, sync_grad=True): - loss = self.loss_scale * loss + if self.mixed_precision_mixin is not None: + loss = self.mixed_precision_mixin.pre_backward(loss) loss.backward(retain_graph=retain_graph) # finish gradient reduction @@ -419,6 +449,8 @@ def zero_grad(self, set_to_none=True): :param set_to_none: Whether set the gradient to None. Default value is True. :type set_to_none: bool """ + if self.mixed_precision_mixin is not None: + self.mixed_precision_mixin.pre_zero_grad() for _, param_group in self._working_param_groups.items(): for param in param_group: if set_to_none: @@ -435,12 +467,7 @@ def zero_grad(self, set_to_none=True): def step(self, closure=None): assert closure is None, 'closure is not supported by step()' - # check for overflow - found_inf = self._check_overflow() - self.grad_scaler.update(found_inf) - - # update loss scale if overflow occurs - if found_inf: + if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): self._grad_store.reset_all_average_gradients() if self._verbose: self._logger.info(f'Found overflow. 
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
    """Divide every flat gradient in place by the unscale divisor, folded with the clip factor.

    Args:
        grad_groups_flat: Iterable of gradient tensors; each is modified in place via ``.data``.
        total_norm: Gradient norm as measured on the still-scaled gradients.
    """
    mixin = self.mixed_precision_mixin
    # without a mixin there is nothing to unscale, so the divisor starts at 1
    divisor = 1.0 if mixin is None else mixin.get_grad_div_scale()

    if self._clip_grad_norm > 0.:
        # total_norm still contains the scale, so divide it out before comparing to the limit
        clip_ratio = ((total_norm / divisor) + 1e-6) / self._clip_grad_norm
        if clip_ratio > 1:
            divisor = clip_ratio * divisor

    for flat_grad in grad_groups_flat:
        flat_grad.data.mul_(1. / divisor)
rtol = 4e-3 + atol = 4e-3 - a = a.detach().half() - b = b.detach().half() + a = a.detach().to(dtype) + b = b.detach().to(dtype) assert_close(a, b, rtol=rtol, atol=atol) @@ -96,7 +99,8 @@ def exam_zero_1_2(): assert torch.equal(z1p.data, z2p.data) -def exam_zero_1_torch_ddp(): +@parameterize('dtype', [torch.float16, torch.bfloat16]) +def exam_zero_1_torch_ddp(dtype: torch.dtype): """ In this test, two pairs of model and optimizers are created. 1. zero: use sharded optimizer and fp16 parameters @@ -109,15 +113,10 @@ def exam_zero_1_torch_ddp(): seed_all(1453) # create models - zero_model = MlpModel() - torch_model = copy.deepcopy(zero_model) + torch_model = MlpModel().cuda() + zero_model = copy.deepcopy(torch_model).to(dtype) - zero_model = zero_model.cuda().half() - torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) - torch_model = torch_model.cuda() - - # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - # half_close(p.data, z1p.data) + torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda() # create optimizer zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) @@ -137,11 +136,11 @@ def exam_zero_1_torch_ddp(): input_data = torch.rand(32, 128).cuda() # zero-dp forward - zero_output = zero_model(input_data.half()) + zero_output = zero_model(input_data.to(dtype)) # torch-ddp forward torch_output = torch_model(input_data) - half_close(zero_output, torch_output, loose=True) + loose_close(zero_output, torch_output, dtype=dtype) # zero-dp backward zero_optimizer.backward(zero_output.mean().float(), sync_grad=False) @@ -151,7 +150,7 @@ def exam_zero_1_torch_ddp(): # check grad for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - half_close(p.grad, z1p.grad, loose=True) + loose_close(p.grad, z1p.grad, dtype=dtype) # zero-dp step zero_optimizer._sync_grad() @@ -163,7 +162,7 @@ def exam_zero_1_torch_ddp(): # check updated param for (n, p), z1p in zip(torch_model.named_parameters(), 
zero_model.parameters()): # print(n, torch.max(torch.abs(p.data - z1p.data))) - half_close(p.data, z1p.data, loose=True) + loose_close(p.data, z1p.data, dtype=dtype) def run_dist(rank, world_size, port):