diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index adbf4803eefe..46714fe1c679 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -23,6 +23,9 @@
 __all__ = ['GeminiPlugin']
 
+SUPPORTED_PRECISION = ['fp16', 'bf16']
+PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}
+
 
 class GeminiCheckpointIO(GeneralCheckpointIO):
 
@@ -171,6 +174,7 @@ class GeminiPlugin(DPPluginBase):
     Args:
         device (torch.device): device to place the model.
         placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
+        precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
         pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
         force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
         strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
@@ -203,6 +207,7 @@ def __init__(
         self,
         device: Optional[torch.device] = None,
         placement_policy: str = "cpu",
+        precision: str = "fp16",
         pin_memory: bool = False,
         force_outputs_fp32: bool = False,
         strict_ddp_mode: bool = False,
@@ -223,6 +228,7 @@ def __init__(
         verbose: bool = False,
     ) -> None:
         super().__init__()
+        assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
         self.gemini_config = dict(
             device=(device or get_current_device()),
             placement_policy=placement_policy,
@@ -233,6 +239,7 @@ def __init__(
             hidden_dim=hidden_dim,
             min_chunk_size_mb=min_chunk_size_mb,
             memstats=memstats,
+            mixed_precision=PRECISION_STR_TO_DTYPE[precision],
         )
         self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
         self.optim_kwargs = dict(initial_scale=initial_scale,
@@ -253,7 +260,7 @@ def control_precision(self) -> bool:
         return True
 
     def supported_precisions(self) -> List[str]:
-        return ['fp16']
+        return SUPPORTED_PRECISION
 
     def control_device(self) -> bool:
         return True
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 5d93cf0e33be..2b312d0f9947 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,4 +1,5 @@
 import warnings
+from functools import partial
 from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
@@ -20,12 +21,15 @@
 __all__ = ['LowLevelZeroPlugin']
 
 
-def _convert_to_fp16(x):
+def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
-        return x.half()
+        return x.to(dtype)
     return x
 
 
+SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
+
+
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
@@ -49,17 +53,24 @@
 class LowLevelZeroModel(ModelWrapper):
 
     def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
         super().__init__(module)
-        self.convert_inputs = (precision == 'fp16')
-        module = zero_model_wrapper(module, zero_stage=stage)
+        self.dtype = None
         if precision == 'fp16':
-            module = module.half()
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        module = zero_model_wrapper(module, zero_stage=stage)
+        if self.dtype is not None:
+            module = module.to(self.dtype)
         module = module.to(get_current_device())
         self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
 
     def forward(self, *args, **kwargs):
-        if self.convert_inputs:
-            args = tree_map(_convert_to_fp16, args)
-            kwargs = tree_map(_convert_to_fp16, kwargs)
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
         return super().forward(*args, **kwargs)
 
@@ -110,7 +121,7 @@ class LowLevelZeroPlugin(DPPluginBase):
 
     Args:
         strage (int, optional): ZeRO stage. Defaults to 1.
-        precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'.
+        precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'.
         initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
         min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
         growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
@@ -149,7 +160,7 @@ def __init__(
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
-        assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training'
+        assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
         self.stage = stage
         self.precision = precision
@@ -175,7 +186,7 @@ def control_precision(self) -> bool:
         return True
 
     def supported_precisions(self) -> List[str]:
-        return ['fp16', 'fp32']
+        return SUPPORTED_PRECISION
 
     def control_device(self) -> bool:
         return True
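Below is a minimal sketch of how the new `precision` option added by this patch would be selected from user code. It relies only on the constructor parameters visible in the diff; the `colossalai.booster.plugin` import path and the values chosen for the other arguments are assumptions for illustration, not part of this change.

```python
# Sketch only: import path and argument values are assumed, not taken from this diff.
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin

# GeminiPlugin now accepts 'bf16' in addition to 'fp16'; the string is looked up
# in PRECISION_STR_TO_DTYPE and passed to the Gemini wrapper as `mixed_precision`.
gemini_plugin = GeminiPlugin(placement_policy='cpu', precision='bf16')

# LowLevelZeroPlugin now accepts 'fp16', 'bf16', or 'fp32'; for fp16/bf16 the
# wrapped module and its floating-point inputs are cast to the chosen dtype.
zero_plugin = LowLevelZeroPlugin(stage=1, precision='bf16')

# An unsupported string now trips the new assertion, e.g.
# GeminiPlugin(precision='fp8') raises "precision fp8 is not supported".
```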