9 changes: 8 additions & 1 deletion colossalai/booster/plugin/gemini_plugin.py
@@ -23,6 +23,9 @@

 __all__ = ['GeminiPlugin']

+SUPPORTED_PRECISION = ['fp16', 'bf16']
+PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}
+

 class GeminiCheckpointIO(GeneralCheckpointIO):

@@ -171,6 +174,7 @@ class GeminiPlugin(DPPluginBase):
     Args:
         device (torch.device): device to place the model.
         placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
+        precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
         pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
         force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
         strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
@@ -203,6 +207,7 @@ def __init__(
         self,
         device: Optional[torch.device] = None,
         placement_policy: str = "cpu",
+        precision: str = "fp16",
         pin_memory: bool = False,
         force_outputs_fp32: bool = False,
         strict_ddp_mode: bool = False,
@@ -223,6 +228,7 @@ def __init__(
         verbose: bool = False,
     ) -> None:
         super().__init__()
+        assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
         self.gemini_config = dict(
             device=(device or get_current_device()),
             placement_policy=placement_policy,
@@ -233,6 +239,7 @@ def __init__(
             hidden_dim=hidden_dim,
             min_chunk_size_mb=min_chunk_size_mb,
             memstats=memstats,
+            mixed_precision=PRECISION_STR_TO_DTYPE[precision],
         )
         self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
         self.optim_kwargs = dict(initial_scale=initial_scale,
@@ -253,7 +260,7 @@ def control_precision(self) -> bool:
         return True

     def supported_precisions(self) -> List[str]:
-        return ['fp16']
+        return SUPPORTED_PRECISION

     def control_device(self) -> bool:
         return True
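For orientation (not part of the diff): a minimal sketch of how the new `precision` argument to `GeminiPlugin` might be used through the standard Booster workflow. The launch call, `HybridAdam`, and the toy model/criterion are illustrative assumptions rather than code from this PR.

# Sketch only: run under torchrun so colossalai.launch_from_torch can read the
# distributed environment; everything besides GeminiPlugin's own arguments is a placeholder.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(8, 8)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

# precision='bf16' is translated to mixed_precision=torch.bfloat16 via
# PRECISION_STR_TO_DTYPE; anything outside SUPPORTED_PRECISION fails the assert.
plugin = GeminiPlugin(placement_policy='cpu', precision='bf16')
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)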
33 changes: 22 additions & 11 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,4 +1,5 @@
 import warnings
+from functools import partial
 from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch
@@ -20,12 +21,15 @@
 __all__ = ['LowLevelZeroPlugin']


-def _convert_to_fp16(x):
+def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
-        return x.half()
+        return x.to(dtype)
     return x


+SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
+
+
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
@@ -49,17 +53,24 @@ class LowLevelZeroModel(ModelWrapper):

     def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
         super().__init__(module)
-        self.convert_inputs = (precision == 'fp16')
-        module = zero_model_wrapper(module, zero_stage=stage)
+        self.dtype = None
         if precision == 'fp16':
-            module = module.half()
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        module = zero_model_wrapper(module, zero_stage=stage)
+        if self.dtype is not None:
+            module = module.to(self.dtype)
         module = module.to(get_current_device())
         self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)

     def forward(self, *args, **kwargs):
-        if self.convert_inputs:
-            args = tree_map(_convert_to_fp16, args)
-            kwargs = tree_map(_convert_to_fp16, kwargs)
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
         return super().forward(*args, **kwargs)


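Aside (not part of the diff): the `convert_fn` pattern above binds the target dtype once with `functools.partial` and lets `tree_map` cast only floating-point tensors inside arbitrarily nested inputs, leaving integer tensors and non-tensors untouched. A standalone sketch, assuming `tree_map` comes from `torch.utils._pytree` as in this module:

# Standalone sketch of the conversion applied in forward().
from functools import partial

import torch
from torch.utils._pytree import tree_map


def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    # Cast only floating-point tensors; everything else passes through unchanged.
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x


convert_fn = partial(_convert_floating_point, dtype=torch.bfloat16)
batch = {'input_ids': torch.ones(2, 4, dtype=torch.long), 'embeddings': torch.rand(2, 4)}
batch = tree_map(convert_fn, batch)
# batch['embeddings'].dtype is now torch.bfloat16; batch['input_ids'] stays torch.long.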
@@ -110,7 +121,7 @@ class LowLevelZeroPlugin(DPPluginBase):

     Args:
         stage (int, optional): ZeRO stage. Defaults to 1.
-        precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'.
+        precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'.
         initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
         min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
         growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
@@ -149,7 +160,7 @@ def __init__(
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
-        assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training'
+        assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'

         self.stage = stage
         self.precision = precision
@@ -175,7 +186,7 @@ def control_precision(self) -> bool:
         return True

     def supported_precisions(self) -> List[str]:
-        return ['fp16', 'fp32']
+        return SUPPORTED_PRECISION

     def control_device(self) -> bool:
         return True
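And the corresponding sketch for `LowLevelZeroPlugin` (again not part of the diff; the launch call and training objects are placeholders): `stage` and `precision` are validated against `(1, 2)` and `SUPPORTED_PRECISION` before the ZeRO wrappers are built, and inputs are cast by `LowLevelZeroModel.forward` at run time.

# Sketch only: bf16 ZeRO-2 training through the plugin interface; everything
# other than LowLevelZeroPlugin's own arguments is a placeholder.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

plugin = LowLevelZeroPlugin(stage=2, precision='bf16')
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
# Floating-point inputs passed to model(...) are now cast to bf16 before forward.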