From ccb8ed126b8a7b440554bda480af7655233e3619 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 14 Aug 2024 15:19:09 +0800 Subject: [PATCH] [plugin] add cast inputs option for zero --- colossalai/booster/plugin/low_level_zero_plugin.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 66491821c375..e4c386a2257d 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -62,7 +62,9 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None: + def __init__( + self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True + ) -> None: super().__init__(module) self.dtype = None if precision == "fp16": @@ -74,7 +76,7 @@ def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = module = module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None - if self.dtype is not None: + if self.dtype is not None and cast_inputs: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) self.overlap_allgather = overlap_allgather if overlap_allgather: @@ -334,6 +336,7 @@ def __init__( cpu_offload: bool = False, master_weights: bool = True, verbose: bool = False, + cast_inputs: bool = True, ) -> None: super().__init__() assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" @@ -360,6 +363,7 @@ def __init__( ) self.lora_enabled = False self.verbose = verbose + self.cast_inputs = cast_inputs # set class name with stage, for better error message setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") @@ -474,7 +478,10 @@ def configure( if not isinstance(model, ModelWrapper): model = LowLevelZeroModel( - model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"] + model, + self.precision, + overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], + cast_inputs=self.cast_inputs, ) # TODO: Support Galore + ZeRO