From 7eacc91fa767b2f40d9e5de1c35b98c08496a660 Mon Sep 17 00:00:00 2001 From: 1SAA Date: Wed, 11 Jan 2023 11:23:34 +0800 Subject: [PATCH 1/4] [ddp] add is_ddp_ignored [ddp] rename to is_ddp_ignored --- colossalai/gemini/chunk/search_utils.py | 9 ++--- colossalai/gemini/chunk/utils.py | 6 ++-- colossalai/nn/optimizer/zero_optimizer.py | 4 +-- colossalai/nn/parallel/data_parallel.py | 12 +++---- colossalai/utils/__init__.py | 42 ++++++++++++++++++----- colossalai/utils/common.py | 8 +++-- colossalai/zero/utils/gemini_hook.py | 5 +-- 7 files changed, 56 insertions(+), 30 deletions(-) diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/gemini/chunk/search_utils.py index 312d77f1826c..572c3d94531f 100644 --- a/colossalai/gemini/chunk/search_utils.py +++ b/colossalai/gemini/chunk/search_utils.py @@ -6,17 +6,14 @@ from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator from colossalai.tensor import ColoParameter - - -def in_ddp(param: nn.Parameter) -> bool: - return not getattr(param, '_ddp_to_ignore', False) +from colossalai.utils import is_ddp_ignored def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: """ Filter those parameters whose size is too large (more than 3x standard deviations) from others. """ - params_size = [p.numel() for p in model.parameters() if in_ddp(p)] + params_size = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)] params_size_arr = np.array(params_size) std = np.std(params_size_arr) @@ -56,7 +53,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int params_dict: Dict[int, List[ColoParameter]] = dict() for param in param_order.generate(): assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" - if not in_ddp(param): + if is_ddp_ignored(param): continue param_key = param.process_group.dp_world_size() diff --git a/colossalai/gemini/chunk/utils.py b/colossalai/gemini/chunk/utils.py index e9a9f84e7a93..883022fe89b8 100644 --- a/colossalai/gemini/chunk/utils.py +++ b/colossalai/gemini/chunk/utils.py @@ -6,8 +6,8 @@ import torch.nn as nn from colossalai.gemini.chunk import ChunkManager -from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration -from colossalai.gemini.memory_tracer import MemStats +from colossalai.gemini.chunk.search_utils import search_chunk_configuration +from colossalai.utils import is_ddp_ignored def init_chunk_manager(model: nn.Module, @@ -34,7 +34,7 @@ def init_chunk_manager(model: nn.Module, if filter_exlarge_params: kwargs_dict["filter_exlarge_params"] = filter_exlarge_params - params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)] + params_sizes = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)] total_size = sum(params_sizes) / 1024**2 dist.barrier() diff --git a/colossalai/nn/optimizer/zero_optimizer.py b/colossalai/nn/optimizer/zero_optimizer.py index 7f9d2fe8fc97..3dd9d1e93b36 100644 --- a/colossalai/nn/optimizer/zero_optimizer.py +++ b/colossalai/nn/optimizer/zero_optimizer.py @@ -12,7 +12,7 @@ from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam from colossalai.nn.parallel.data_parallel import ZeroDDP -from colossalai.utils import disposable, get_current_device +from colossalai.utils import disposable, get_current_device, is_ddp_ignored _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} @@ -78,7 +78,7 @@ def __init__(self, if self.clipping_flag: assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" - params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)] + params_list = [p for p in module.parameters() if not is_ddp_ignored(p)] for p, fp32_p in zip(params_list, module.fp32_params): chunk_16 = self.chunk_manager.get_chunk(p) if chunk_16 not in self.chunk16_set: diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 5e547059a937..649bd920d3b2 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -14,7 +14,7 @@ from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.utils import get_current_device +from colossalai.utils import get_current_device, is_ddp_ignored from colossalai.zero.utils.gemini_hook import GeminiZeROHook from .reducer import Reducer @@ -81,7 +81,7 @@ def __init__(self, self.reducer = Reducer(bucket_cap_mb) self.rebuild_bucket = rebuild_bucket for p in module.parameters(): - if getattr(p, '_ddp_to_ignore', False): + if is_ddp_ignored(p): continue if p.requires_grad: p.register_hook(partial(self.grad_handle, p)) @@ -116,7 +116,7 @@ def backward(self, loss: torch.Tensor): if self.rebuild_bucket: self.reducer.free() for p in self.module.parameters(): - if getattr(p, '_ddp_to_ignore', False): + if is_ddp_ignored(p): continue if p.grad.device.type != "cpu": p.grad = p._saved_grad @@ -232,7 +232,7 @@ def __init__(self, for p in param_order.generate(): assert isinstance(p, ColoParameter) - if getattr(p, '_ddp_to_ignore', False): + if is_ddp_ignored(p): p.data = p.data.half() continue @@ -256,7 +256,7 @@ def __init__(self, self.chunk_manager.close_all_groups() self._cast_buffers() - params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)] + params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)] for p, fp32_p in zip(params_list, self.fp32_params): chunk_16 = self.chunk_manager.get_chunk(p) chunk_32 = self.chunk_manager.get_chunk(fp32_p) @@ -303,7 +303,7 @@ def forward(self, *args, **kwargs): def _setup_grads_ptr(self): for p in self.module.parameters(): - if getattr(p, '_ddp_to_ignore', False): + if is_ddp_ignored(p): continue p.grad = None diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 875b5a93ba4f..3f16bd91e5fe 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -1,22 +1,46 @@ -from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize from .activation_checkpoint import checkpoint from .checkpointing import load_checkpoint, save_checkpoint -from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32, - ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, - is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier, - param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank, - sync_model_param, disposable) +from .common import ( + clip_grad_norm_fp32, + conditional_context, + copy_tensor_parallel_attributes, + count_zeros_fp32, + disposable, + ensure_path_exists, + free_port, + is_ddp_ignored, + is_dp_rank_0, + is_model_parallel_parameter, + is_no_pp_or_last_stage, + is_tp_rank_0, + is_using_ddp, + is_using_pp, + is_using_sequence, + multi_tensor_applier, + param_is_not_tensor_parallel_duplicate, + print_rank_0, + switch_virtual_pipeline_parallel_rank, + sync_model_param, +) +from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize from .data_sampler import DataParallelSampler, get_dataloader -from .memory import (report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, - colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity) -from .timer import MultiTimer, Timer +from .memory import ( + colo_device_memory_capacity, + colo_device_memory_used, + colo_get_cpu_memory_capacity, + colo_set_cpu_memory_capacity, + colo_set_process_memory_fraction, + report_memory_usage, +) from .tensor_detector import TensorDetector +from .timer import MultiTimer, Timer __all__ = [ 'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', + 'is_ddp_ignored', 'is_dp_rank_0', 'is_tp_rank_0', 'is_no_pp_or_last_stage', diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 7575fa292f14..2099883fbdf7 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -126,14 +126,18 @@ def is_model_parallel_parameter(p): return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) +def is_ddp_ignored(p): + return getattr(p, '_ddp_to_ignore', False) + + def _calc_l2_norm(grads): - # we should not + # we should not global fused_optim if fused_optim is None: from colossalai.kernel.op_builder import FusedOptimBuilder fused_optim = FusedOptimBuilder().load() - + norm = 0.0 if len(grads) > 0: dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/colossalai/zero/utils/gemini_hook.py b/colossalai/zero/utils/gemini_hook.py index 35569c7172b3..bddc307a0504 100644 --- a/colossalai/zero/utils/gemini_hook.py +++ b/colossalai/zero/utils/gemini_hook.py @@ -8,6 +8,7 @@ from colossalai.gemini import TensorState from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.tensor.param_op_hook import ColoParamOpHook +from colossalai.utils import is_ddp_ignored class TrainingPhase(Enum): @@ -24,7 +25,7 @@ def __init__(self, gemini_manager: GeminiManager) -> None: self._training_phase = TrainingPhase.FORWARD def pre_op(self, params): - params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)] + params = [p for p in params if not is_ddp_ignored(p)] chunks = self._chunk_manager.get_chunks(params) for p in params: self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE) @@ -37,7 +38,7 @@ def pre_op(self, params): self._gemini_manager.record_model_data_volume() def post_op(self, params): - params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)] + params = [p for p in params if not is_ddp_ignored(p)] for p in params: tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD self._chunk_manager.trans_tensor_state(p, tensor_state) From 16a3af55350e9118ddd5ef9de9892d8183e81bd0 Mon Sep 17 00:00:00 2001 From: 1SAA Date: Wed, 11 Jan 2023 13:12:58 +0800 Subject: [PATCH 2/4] [zero] fix state_dict and load_state_dict --- colossalai/nn/parallel/data_parallel.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 649bd920d3b2..bc64f979228a 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -233,7 +233,7 @@ def __init__(self, assert isinstance(p, ColoParameter) if is_ddp_ignored(p): - p.data = p.data.half() + p.data = p.data.to(device=get_current_device(), dtype=torch.float16) continue fp32_data = p.data.float() @@ -451,8 +451,14 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0) - # TODO: (HELSON) deal with ddp ignored parameters - for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): + ddp_param_list = [] + for name, param in self.named_parameters(): + if is_ddp_ignored(param): + # deal with ddp ignored parameters + destination[prefix + name] = param if keep_vars else param.detach() + else: + ddp_param_list.append(param) + for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params): if p is not None: assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) record_parameter = param_to_save_data[fp32_p] @@ -588,8 +594,16 @@ def load(param_name, dest_tensor, copy_func): def load_fp32_parameter(chunk_slice, data): chunk_slice.copy_(data.flatten()) + ddp_param_list = [] + for name, param in self.named_parameters(): + if is_ddp_ignored(param): + # deal with ddp ignored parameters + load(name, param, param.copy_) + else: + ddp_param_list.append(param) + fp32_to_name = dict() - for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): + for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params): if p is not None: fp32_to_name[fp32_p] = name From d909fffb30707a21506026e80a30921206a49466 Mon Sep 17 00:00:00 2001 From: 1SAA Date: Wed, 11 Jan 2023 13:18:44 +0800 Subject: [PATCH 3/4] fix bugs --- colossalai/nn/parallel/data_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index bc64f979228a..28a10c4b6c92 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -457,7 +457,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): # deal with ddp ignored parameters destination[prefix + name] = param if keep_vars else param.detach() else: - ddp_param_list.append(param) + ddp_param_list.append((name, param)) for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params): if p is not None: assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) @@ -600,7 +600,7 @@ def load_fp32_parameter(chunk_slice, data): # deal with ddp ignored parameters load(name, param, param.copy_) else: - ddp_param_list.append(param) + ddp_param_list.append((name, param)) fp32_to_name = dict() for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params): From e8ce95dcb58b3ca754fa11d2c3c05ee973813f47 Mon Sep 17 00:00:00 2001 From: 1SAA Date: Wed, 11 Jan 2023 13:42:16 +0800 Subject: [PATCH 4/4] [zero] update unit test for ZeroDDP --- tests/test_gemini/update/test_zeroddp_state_dict.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py index b902bb0f010e..266b8eab121b 100644 --- a/tests/test_gemini/update/test_zeroddp_state_dict.py +++ b/tests/test_gemini/update/test_zeroddp_state_dict.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp +from torch.testing import assert_close import colossalai from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration @@ -17,6 +18,13 @@ from tests.test_tensor.common_utils import debug_print, set_seed +def ignore_the_first_parameter(model: torch.nn.Module): + for name, param in model.named_parameters(): + print(f"parameter `{name}` is set ignored") + ZeroDDP.set_params_to_ignore([param]) + return + + @parameterize('placement_policy', ['cuda', 'cpu', 'auto']) @parameterize('keep_gathered', [True, False]) @parameterize('model_name', ['gpt2', 'bert']) @@ -47,7 +55,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str): for key, value in torch_dict.items(): assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) - assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) + assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) @parameterize('placement_policy', ['cuda', 'cpu', 'auto']) @@ -84,7 +92,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): for key, value in torch_dict.items(): assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) - assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) + assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) def run_dist(rank, world_size, port):