diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 878c25be7094..7e230896aac8 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -2,7 +2,7 @@
 from collections import OrderedDict
 from contextlib import nullcontext
 from functools import partial
-from typing import Dict, Iterator, List, Optional, Union, Tuple, Set
+from typing import Dict, Iterator, List, Optional, Set, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -51,6 +51,7 @@ class ZeroDDP(ColoDDP):
         strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
             Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
         scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
+        mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
     """

     def __init__(self,
@@ -59,7 +60,9 @@ def __init__(self,
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False,
-                 scatter_after_inference: bool = True) -> None:
+                 scatter_after_inference: bool = True,
+                 mixed_precision: torch.dtype = torch.float16) -> None:
+        assert mixed_precision in (torch.float16, torch.bfloat16)
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
@@ -71,6 +74,7 @@ def __init__(self,
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()
         self.scatter_after_inference = scatter_after_inference
+        self.mixed_precision = mixed_precision

         self._logger = get_dist_logger()

@@ -96,34 +100,38 @@ def __init__(self,
             param_name = m_name + '.' + p_name if m_name else p_name
             self.name2param[param_name] = p_var
         super().__init__(module, process_group=ColoProcessGroup())
-        self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
+        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()

-    def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
-
-        r"""
-        Args:
-            memo: a memo to store the set of modules already added to the result
-            prefix: a prefix that will be added to the name of the module
-            remove_duplicate: whether to remove the duplicated module instances in the result
-                or not
-        """
-
-        if memo is None:
-            memo = set()
-        self_non_persistent_set = set()
-        if module not in memo:
-            if remove_duplicate:
-                memo.add(module)
-            self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
-            for name, sub_module in module._modules.items():
-                if sub_module is None:
-                    continue
-                submodule_prefix = prefix + ('.' if prefix else '') + name
-                child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
-                self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
-        return self_non_persistent_set
-
+    def _get_non_persistent_buffers_set(self,
+                                        module,
+                                        memo: Optional[Set[nn.Module]] = None,
+                                        prefix: str = '',
+                                        remove_duplicate: bool = True):
+        r"""
+        Args:
+            memo: a memo to store the set of modules already added to the result
+            prefix: a prefix that will be added to the name of the module
+            remove_duplicate: whether to remove the duplicated module instances in the result
+                or not
+        """
+
+        if memo is None:
+            memo = set()
+        self_non_persistent_set = set()
+        if module not in memo:
+            if remove_duplicate:
+                memo.add(module)
+            self_non_persistent_set = set(
+                map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
+            for name, sub_module in module._modules.items():
+                if sub_module is None:
+                    continue
+                submodule_prefix = prefix + ('.' if prefix else '') + name
+                child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix,
+                                                                                remove_duplicate)
+                self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+        return self_non_persistent_set

     def _post_forward(self):
         """This function is only triggered for inference.
@@ -147,7 +155,7 @@ def forward(self, *args, **kwargs):
         assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
         ), "You should run a completed iteration as your warmup iter"

-        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
+        args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
         self.module.zero_grad(set_to_none=True)
         if not grad_flag:
             outputs = self._inference_forward(*args, **kwargs)
@@ -566,14 +574,14 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi

             # move ignored parameters to CUDA
             if is_ddp_ignored(p):
-                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
+                p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
                 continue

             # create a fp32 parameter
             fp32_data = p.data.float()
             fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
             # create a fp16 parameter
-            p.data = p.data.half()
+            p.data = p.data.to(self.mixed_precision)

             # register the fp16 parameter and fp32 parameter in the chunk manager
             dp_world_size = p.process_group.dp_world_size()
@@ -609,7 +617,7 @@ def _cast_buffers(self):
                     buffer.materialize()
                 buffer.data = buffer.cuda()
                 if torch.is_floating_point(buffer):
-                    buffer.data = buffer.half()
+                    buffer.data = buffer.to(self.mixed_precision)

     def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
         """Convert parameter to ColoParameter in-place.
@@ -732,6 +740,7 @@ def __init__(self,
                  hidden_dim: Optional[int] = None,
                  min_chunk_size_mb: float = 32,
                  memstats: Optional[MemStats] = None,
+                 mixed_precision: torch.dtype = torch.float16,
                  verbose: bool = False) -> None:
         """
         A torch.Module wrapper using ZeRO-DP and Gemini.
@@ -772,5 +781,10 @@ def __init__(self,
                                            strict_ddp_flag=strict_ddp_mode,
                                            verbose=verbose)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
-        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
-                         scatter_after_inference)
+        super().__init__(module,
+                         gemini_manager,
+                         pin_memory,
+                         force_outputs_fp32,
+                         strict_ddp_mode,
+                         scatter_after_inference,
+                         mixed_precision=mixed_precision)
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 71c4f65cb8d2..267deb1e8699 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -1,7 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import math
 import warnings
-from enum import Enum
 from typing import Any, Dict, Set, Tuple

 import torch
@@ -9,7 +8,7 @@
 from torch.nn import Parameter
 from torch.optim import Optimizer

-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
 from colossalai.utils import disposable, get_current_device, is_ddp_ignored
@@ -22,9 +21,26 @@
 _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}


-class OptimState(Enum):
-    SCALED = 0
-    UNSCALED = 1
+class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
+
+    def __init__(self,
+                 module: ZeroDDP,
+                 initial_scale: float = 2**16,
+                 min_scale: float = 1,
+                 growth_factor: float = 2,
+                 backoff_factor: float = 0.5,
+                 growth_interval: int = 1000,
+                 hysteresis: int = 2,
+                 max_scale: float = 2**32) -> None:
+        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
+                         max_scale)
+        self.module = module
+
+    def check_local_overflow(self) -> bool:
+        return self.module.overflow_counter > 0
+
+    def pre_zero_grad(self) -> None:
+        self.module.overflow_counter = 0


 class ZeroOptimizer(ColossalaiOptimizer):
@@ -79,7 +95,6 @@ def __init__(self,
         self.module = module
         self.gemini_manager = module.gemini_manager
         self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
-        self.optim_state = OptimState.UNSCALED
         self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
         self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
         self.chunk16_set: Set[Chunk] = set()
@@ -107,15 +122,20 @@ def __init__(self,

         self.__init__optimizer()

-        # Grad scaler
-        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
-                                             min_scale=min_scale,
-                                             growth_factor=growth_factor,
-                                             backoff_factor=backoff_factor,
-                                             growth_interval=growth_interval,
-                                             hysteresis=hysteresis,
-                                             max_scale=max_scale)
-        self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+        if module.mixed_precision is torch.float16:
+            self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module,
+                                                                     initial_scale=initial_scale,
+                                                                     min_scale=min_scale,
+                                                                     growth_factor=growth_factor,
+                                                                     backoff_factor=backoff_factor,
+                                                                     growth_interval=growth_interval,
+                                                                     hysteresis=hysteresis,
+                                                                     max_scale=max_scale)
+        elif module.mixed_precision is torch.bfloat16:
+            self.mix_precision_mixin = BF16MixedPrecisionMixin()
+        else:
+            raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}")
+
         self._logger = get_dist_logger()

         self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
@@ -151,15 +171,6 @@ def _update_fp16_params(self):
         for chunk16 in self.chunk16_set:
             chunk16.optim_update()

-    def _check_overflow(self):
-        # clear previous overflow record
-        self._found_overflow.fill_(self.module.overflow_counter)
-
-        # all-reduce across global group
-        dist.all_reduce(self._found_overflow)
-
-        return self._found_overflow.item() > 0
-
     def _clear_global_norm(self) -> None:
         for c16 in self.chunk16_set:
             c16.l2_norm = None
@@ -190,40 +201,25 @@ def _calc_global_norm(self) -> float:
         return global_norm

     def _get_combined_scale(self):
-        loss_scale = 1
-
-        if self.optim_state == OptimState.SCALED:
-            loss_scale = self.loss_scale
-            self.optim_state = OptimState.UNSCALED
+        div_scale = self.mix_precision_mixin.get_grad_div_scale()

-        combined_scale = loss_scale
         if self.clipping_flag:
             total_norm = self._calc_global_norm()
-            clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
+            clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
             if clip > 1:
-                combined_scale = clip * loss_scale
+                div_scale = clip * div_scale

-        if combined_scale == 1:
-            return -1
-        else:
-            return combined_scale
-
-    @property
-    def loss_scale(self):
-        return self.grad_scaler.scale.item()
+        return -1 if div_scale == 1.0 else div_scale

     def zero_grad(self, *args, **kwargs):
-        self.module.overflow_counter = 0
+        self.mix_precision_mixin.pre_zero_grad()
         return self.optim.zero_grad(set_to_none=True)

     def step(self, *args, **kwargs):
         self._maybe_move_fp32_params()
         self._set_grad_ptr()

-        found_inf = self._check_overflow()
-        if found_inf:
-            self.optim_state = OptimState.UNSCALED    # no need to unscale grad
-            self.grad_scaler.update(found_inf)    # update gradient scaler
+        if self.mix_precision_mixin.should_skip_step():
             if self.verbose:
                 self._logger.info(f'Found overflow. Skip step')
             self._clear_global_norm()    # clear recorded norm
@@ -234,7 +230,6 @@ def step(self, *args, **kwargs):
         # get combined scale. combined scale = loss scale * clipping norm
         # so that gradient = gradient / combined scale
         combined_scale = self._get_combined_scale()
-        self.grad_scaler.update(found_inf)

         ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
         self._register_states()
@@ -246,8 +241,7 @@ def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: flo
         raise NotImplementedError

     def backward(self, loss: torch.Tensor):
-        loss = self.loss_scale * loss
-        self.optim_state = OptimState.SCALED
+        loss = self.mix_precision_mixin.pre_backward(loss)
         self.module.backward(loss)

     def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
@@ -255,7 +249,7 @@ def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
         # It receives the scaled grad from the previous rank
         # No need to scale the grad again
         # Need to unscale when optimizing
-        self.optim_state = OptimState.SCALED
+        grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
         self.module.backward_by_grad(tensor, grad)

     def _maybe_move_fp32_params(self):
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index 8ce20c16e8f9..66611bcd2419 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -21,23 +21,40 @@
 # these models are too small, all parameters in these models are compacted into one chunk
 EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers']

+# bfloat16 cannot represent them exactly
+BF16_IGNORED_KEYS = [
+    'albert.embeddings.word_embeddings.weight',
+    'albert.embeddings.position_embeddings.weight',
+    'masked_bias',
+]

-def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
-    zero_dict = model.state_dict(only_rank_0=False)
+
+def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
+    zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
     torch_dict = torch_model.state_dict()

     for key, value in torch_dict.items():
         # key is 'module.model.PARAMETER', so we truncate it
         key = key[7:]
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
-        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
+        temp_zero_value = zero_dict[key].to(device=value.device)
+        if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS):
+            continue
+        rtol, atol = 1e-3, 4e-3
+        if dtype is torch.bfloat16:
+            rtol, atol = 4e-3, 8e-3
         # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
-        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
+        assert_close(value.float(),
+                     temp_zero_value.float(),
+                     rtol=rtol,
+                     atol=atol,
+                     msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}')


 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
 @parameterize('model_name', TEST_MODELS)
-def exam_model_step(placement_policy, model_name: str):
+@parameterize('mixed_precision', [torch.half, torch.bfloat16])
+def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -65,7 +82,7 @@ def exam_model_step(placement_policy, model_name: str):
         init_device = None
     chunk_manager = ChunkManager(config_dict, init_device=init_device)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)

     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
@@ -74,6 +91,7 @@ def exam_model_step(placement_policy, model_name: str):
     torch_model.eval()

     set_seed(dist.get_rank() * 3 + 128)
+    rtol, atol = 1e-4, 1e-5
     for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break
@@ -83,17 +101,18 @@
         torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)

-        assert_close(torch_loss, loss)
+        assert_close(torch_loss, loss, rtol=rtol, atol=atol)

         zero_optim.step()
         torch_optim.step()

-        check_param(model, torch_model)
+        check_param(model, torch_model, mixed_precision)


 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
 @parameterize('model_name', EXAMPLE_MODELS)
-def exam_tiny_example(placement_policy, model_name: str):
+@parameterize('mixed_precision', [torch.half, torch.bfloat16])
+def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype):
     set_seed(2008)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -113,7 +132,7 @@ def exam_tiny_example(placement_policy, model_name: str):

     chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)

     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
@@ -121,6 +140,9 @@
     torch_model.eval()

     set_seed(dist.get_rank() * 3 + 128)
+    rtol, atol = 1.5e-6, 2e-5
+    if mixed_precision is torch.bfloat16:
+        rtol, atol = 2e-3, 2e-3
     for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break
@@ -133,12 +155,12 @@
         torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)

-        assert_close(torch_loss, loss, rtol=1.5e-6, atol=2e-5)    # atol should be 2e-5 for torch lower than 1.12
+        assert_close(torch_loss, loss, rtol=rtol, atol=atol)    # atol should be 2e-5 for torch lower than 1.12

         zero_optim.step()
         torch_optim.step()

-        check_param(model, torch_model)
+        check_param(model, torch_model, mixed_precision)


 def run_dist(rank, world_size, port):
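
Note (appended for readers, not part of the patch): a minimal usage sketch of the new mixed_precision option, mirroring the setup exercised by test_optim.py above. The import paths, the 'cuda' placement policy, and the helper name build_bf16_gemini are assumptions inferred from the file paths and test code in this diff, not confirmed public APIs.

# Hypothetical usage sketch: wrap a model with ZeroDDP in bf16 and drive it with ZeroOptimizer.
# With mixed_precision=torch.bfloat16, ZeroOptimizer selects BF16MixedPrecisionMixin internally,
# so no dynamic loss scaling is applied; torch.float16 keeps the FP16 scaler path.
import torch

from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
# import paths below are guesses based on the files touched by this patch
from colossalai.zero.gemini.chunk import init_chunk_manager
from colossalai.zero.gemini.gemini_ddp import ZeroDDP
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer


def build_bf16_gemini(model: torch.nn.Module):
    # chunk search and placement mirror exam_tiny_example ('cuda' placement assumed)
    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
    gemini_manager = GeminiManager('cuda', chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=torch.bfloat16)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model)    # scaler args (initial_scale, ...) are unused in bf16 mode
    # training step: loss = ...; zero_optim.backward(loss); zero_optim.step(); zero_optim.zero_grad()
    return model, zero_optim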