From c7106f92d6a7483a98a26432b1a56f6176b6dd79 Mon Sep 17 00:00:00 2001 From: lclgy Date: Mon, 3 Jul 2023 15:14:10 +0800 Subject: [PATCH 1/3] support no sync for zero1 plugin --- colossalai/booster/booster.py | 12 ++++--- colossalai/booster/plugin/gemini_plugin.py | 7 ++-- .../booster/plugin/low_level_zero_plugin.py | 18 ++++++---- colossalai/booster/plugin/plugin_base.py | 2 +- colossalai/booster/plugin/torch_ddp_plugin.py | 2 +- .../booster/plugin/torch_fsdp_plugin.py | 2 +- colossalai/zero/low_level/low_level_optim.py | 30 +++++++--------- .../test_zero/test_low_level/test_grad_acc.py | 35 ++++++++----------- 8 files changed, 51 insertions(+), 57 deletions(-) diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index cee547b33b0c..ec3dc7fc143f 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -9,7 +9,7 @@ from torch.utils.data import DataLoader from colossalai.checkpoint_io import GeneralCheckpointIO -from colossalai.interface import ModelWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from .accelerator import Accelerator from .mixed_precision import MixedPrecision, mixed_precision_factory @@ -153,18 +153,20 @@ def execute_pipeline(self, # return loss or outputs if needed pass - def no_sync(self, model: nn.Module) -> contextmanager: + def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager: """Context manager to disable gradient synchronization across DP process groups. + Support torch DDP and Low Level ZeRO-1 for now. Args: - model (nn.Module): The model to be disabled gradient synchronization. + model (nn.Module): The model to be disabled gradient synchronization, for DDP + optimizer (OptimizerWrapper): The optimizer to be disabled gradient synchronization, for ZeRO-1 Returns: contextmanager: Context to disable gradient synchronization. 
""" assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.' - assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' - return self.plugin.no_sync(model) + assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' + return self.plugin.no_sync(model, optimizer) def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True): """Load model from checkpoint. diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 60b25b2c400c..2f0aae189a8e 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -295,10 +295,7 @@ def configure( if optimizer is not None and \ not isinstance(optimizer, OptimizerWrapper): - optimizer = GeminiOptimizer(model.unwrap(), - optimizer, - self.zero_optim_config, - self.optim_kwargs, + optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, self.verbose) return model, optimizer, criterion, dataloader, lr_scheduler @@ -309,5 +306,5 @@ def control_checkpoint_io(self) -> bool: def get_checkpoint_io(self) -> CheckpointIO: return GeminiCheckpointIO() - def no_sync(self, model: nn.Module) -> Iterator[None]: + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 94d722080367..2ed01f65e953 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -179,8 +179,14 @@ def __init__( norm_type=norm_type) self.verbose = verbose + # set class name with stage, for better error message + setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") + def 
support_no_sync(self) -> bool: - return False + if self.stage == 1: + return True + else: + return False def control_precision(self) -> bool: return True @@ -208,10 +214,7 @@ def configure( if optimizer is not None and \ not isinstance(optimizer, OptimizerWrapper): - optimizer = LowLevelZeroOptimizer(model.unwrap(), - optimizer, - self.zero_optim_config, - self.optim_kwargs, + optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, self.verbose) return model, optimizer, criterion, dataloader, lr_scheduler @@ -222,5 +225,6 @@ def control_checkpoint_io(self) -> bool: def get_checkpoint_io(self) -> CheckpointIO: return LowLevelZeroCheckpointIO() - def no_sync(self, model: nn.Module) -> Iterator[None]: - raise NotImplementedError + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert isinstance(optimizer, LowLevelZeroOptimizer) + return optimizer.optim.no_sync() diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py index aa78f6827003..fb21e57f41f7 100644 --- a/colossalai/booster/plugin/plugin_base.py +++ b/colossalai/booster/plugin/plugin_base.py @@ -61,7 +61,7 @@ def get_checkpoint_io(self) -> CheckpointIO: pass @abstractmethod - def no_sync(self, model: nn.Module) -> Iterator[None]: + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: """ Context manager to disable gradient synchronization. 
""" diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 71b435155503..f3f779c88e42 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -168,6 +168,6 @@ def control_checkpoint_io(self) -> bool: def get_checkpoint_io(self) -> CheckpointIO: return TorchDDPCheckpointIO() - def no_sync(self, model: nn.Module) -> Iterator[None]: + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.' return model.module.no_sync() diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index abfffa9b099e..fb7b5baadd0c 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -177,7 +177,7 @@ def __init__( def support_no_sync(self) -> bool: False - def no_sync(self, model: nn.Module) -> Iterator[None]: + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: raise NotImplementedError("Torch fsdp no_sync func not supported yet.") def control_precision(self) -> bool: diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 8743cab3313f..bb852f03988e 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -14,10 +14,10 @@ ) from colossalai.context import ParallelMode from colossalai.core import global_context as gpc +from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.tensor import ColoParameter, ProcessGroup -from colossalai.utils import conditional_context from colossalai.utils.cuda import get_current_device from ._utils import ( @@ -56,7 +56,7 @@ def check_local_overflow(self) -> bool: return False 
-class LowLevelZeroOptimizer(ColossalaiOptimizer): +class LowLevelZeroOptimizer(OptimizerWrapper): """Optimizer used for ZeRO-1 and ZeRO-2. """ @@ -77,11 +77,13 @@ def __init__( overlap_communication: bool = False, partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload - grad_accumulate_interval: int = 1, forced_dtype: Optional[torch.dtype] = None): - assert not (partition_grad and grad_accumulate_interval > 1), \ - "gradient accumulation is not compatible with ZeRO-2" + # TODO: + # 1. process group api + # 2. state dict + # 3. gradient accumulation + super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]['params'][0].dtype self._logger = get_dist_logger() @@ -94,8 +96,6 @@ def __init__( # grad accumulation self.require_grad_sync = True - self._accumulate_intervel = grad_accumulate_interval - self._accumulate_step = 0 colo_pg = self._search_colo_process_group() if isinstance(colo_pg, ProcessGroup): @@ -340,15 +340,15 @@ def _add_to_bucket(self, param, group_id): ################################ def backward(self, loss, retain_graph=False): + assert not(self._partition_grads and not self.require_grad_sync), \ + "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" + if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) - self._accumulate_step += 1 - no_sync = self._accumulate_step < self._accumulate_intervel - with conditional_context(self.no_sync(), enable=no_sync): - loss.backward(retain_graph=retain_graph) + loss.backward(retain_graph=retain_graph) - if no_sync: + if not self.require_grad_sync: return self._reduce_grad(self._partition_grads) @@ -385,7 +385,7 @@ def zero_grad(self, set_to_none=True): def step(self, closure=None): assert closure is None, 'closure is not supported by step()' - if not self._accumulate_step == self._accumulate_intervel: + if not self.require_grad_sync: return if self.mixed_precision_mixin is 
not None and self.mixed_precision_mixin.should_skip_step(): @@ -393,7 +393,6 @@ def step(self, closure=None): if self._verbose: self._logger.info(f'Found overflow. Skip step') self.zero_grad() - self._accumulate_step -= 1 return # record all grads for unscale and clip @@ -463,9 +462,6 @@ def step(self, closure=None): self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id] - # reset accumulate step - self._accumulate_step = 0 - ############################# # Mixed Precision Utilities # ############################# diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index ac1f677f9a0d..a1d14f1d5a9d 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -9,6 +9,7 @@ import colossalai from colossalai.testing import spawn from colossalai.testing.random import seed_all +from colossalai.utils import conditional_context from colossalai.zero import LowLevelZeroOptimizer @@ -39,14 +40,12 @@ def exam_zero_1_2_grad_acc(): overlap_communication=True, initial_scale=32, clip_grad_norm=1.0, - grad_accumulate_interval=2, verbose=True) zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=32, - clip_grad_norm=1.0, - grad_accumulate_interval=2) + clip_grad_norm=1.0) # create data seed_all(2021 + local_rank) input_data1 = torch.randn(32, 128).cuda() @@ -59,8 +58,11 @@ def fwd_bwd_func(number, cur_data, check_flag): assert torch.equal(zero1_output, zero2_output) # zero-dp backward - zero1_optimizer.backward(zero1_output.sum().float()) - zero2_optimizer.backward(zero2_output.sum().float()) + no_sync = number == 0 + with conditional_context(zero1_optimizer.no_sync(), no_sync): + zero1_optimizer.backward(zero1_output.sum().float()) + with conditional_context(zero2_optimizer.no_sync(), no_sync): + zero2_optimizer.backward(zero2_output.sum().float()) if 
check_flag: for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): @@ -101,8 +103,7 @@ def exam_zero_1_grad_acc(): zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, overlap_communication=False, reduce_bucket_size=262144, - clip_grad_norm=1.0, - grad_accumulate_interval=2) + clip_grad_norm=1.0) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) @@ -112,20 +113,15 @@ def exam_zero_1_grad_acc(): input_data2 = torch.randn(32, 128).cuda() def fwd_bwd_func(number, cur_data, check_flag): - # zero-dp forward - zero_output = zero_model(cur_data) - # torch-ddp forward + no_sync = number == 0 + # zero1 fwd and bwd + with conditional_context(zero_optimizer.no_sync(), no_sync): + zero_output = zero_model(cur_data) + zero_optimizer.backward(zero_output.sum().float()) - # zero-dp backward - zero_optimizer.backward(zero_output.sum().float()) - # torch-ddp backward - if number < 1: - with torch_model.no_sync(): - torch_output = torch_model(cur_data) - assert torch.equal(zero_output, torch_output) - torch_output.sum().backward() - else: + # torch-ddp fwd and bwd + with conditional_context(torch_model.no_sync(), no_sync): torch_output = torch_model(cur_data) assert torch.equal(zero_output, torch_output) torch_output.sum().backward() @@ -133,7 +129,6 @@ def fwd_bwd_func(number, cur_data, check_flag): if check_flag: # check grad for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad))) assert torch.equal(p.grad, z1p.grad) fwd_bwd_func(0, input_data1, True) From 92ee707a4dfb59e067a6cdcb31426d2fda7293ff Mon Sep 17 00:00:00 2001 From: lclgy Date: Mon, 3 Jul 2023 15:23:59 +0800 Subject: [PATCH 2/3] polish --- colossalai/zero/low_level/low_level_optim.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index bb852f03988e..615c870971b1 100644 
--- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -81,8 +81,7 @@ def __init__( # TODO: # 1. process group api - # 2. state dict - # 3. gradient accumulation + # 2. checkpoint IO super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]['params'][0].dtype From 72b7e0d98e2b369958256d7d76f7bf8a3776f18e Mon Sep 17 00:00:00 2001 From: lclgy Date: Tue, 4 Jul 2023 10:56:16 +0800 Subject: [PATCH 3/3] polish --- colossalai/booster/plugin/low_level_zero_plugin.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 2ed01f65e953..0a3221b231bc 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -183,10 +183,7 @@ def __init__( setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") def support_no_sync(self) -> bool: - if self.stage == 1: - return True - else: - return False + return self.stage == 1 def control_precision(self) -> bool: return True