From dd3657dcb0b3aa05084dee4a24fc4549fe89fafc Mon Sep 17 00:00:00 2001 From: lclgy Date: Tue, 4 Jul 2023 13:54:00 +0800 Subject: [PATCH 1/3] allow passing process group to zero12 --- colossalai/zero/__init__.py | 5 +- colossalai/zero/low_level/__init__.py | 4 +- colossalai/zero/low_level/_utils.py | 48 +++---- colossalai/zero/low_level/low_level_optim.py | 121 +++++++++--------- .../test_low_level/test_zero_init.py | 5 +- .../test_zero/test_low_level/test_zero_tp.py | 14 +- 6 files changed, 97 insertions(+), 100 deletions(-) diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 3465079e4fbb..ce5dabe742b5 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -7,10 +7,11 @@ get_static_torch_model, post_process_colo_init_ctx, ) -from .low_level import LowLevelZeroOptimizer +from .low_level import LowLevelZeroOptimizer, TPLowLevelZeroOptimizer from .wrapper import zero_model_wrapper, zero_optim_wrapper __all__ = [ 'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper', - 'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model' + 'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model', + 'TPLowLevelZeroOptimizer' ] diff --git a/colossalai/zero/low_level/__init__.py b/colossalai/zero/low_level/__init__.py index ae3c1de3a5bc..ace04c32ccb0 100644 --- a/colossalai/zero/low_level/__init__.py +++ b/colossalai/zero/low_level/__init__.py @@ -1,3 +1,3 @@ -from .low_level_optim import LowLevelZeroOptimizer +from .low_level_optim import LowLevelZeroOptimizer, TPLowLevelZeroOptimizer -__all__ = ['LowLevelZeroOptimizer'] +__all__ = ['LowLevelZeroOptimizer', 'TPLowLevelZeroOptimizer'] diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index a9e552ebdabc..4205a9891534 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -3,8 +3,9 @@ import torch import torch.distributed as dist -from torch import inf +from torch import Tensor, inf from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.distributed import ProcessGroup from colossalai.tensor import ColoParameter from colossalai.utils import is_model_parallel_parameter @@ -194,26 +195,21 @@ def calculate_global_norm_from_list(norm_list): return math.sqrt(total_norm) -def compute_norm(gradients, params, dp_group, mp_group, norm_type=2): +def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int: """Clips gradient norm of an iterable of parameters. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. Note that - the gradients are modified in place. - Arguments: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. + added functionality to handle model parallel parameters. + + Args: + gradients (Tensor): The gradients to compute norm + dp_group (ProcessGroup): The process group of ZeRO Data Parallelism + tp_group (ProcessGroup): The process group of Tensor Parallelism + norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2. + Returns: - Total norm of the parameters (viewed as a single vector). + int: The total norm of given gradients """ - if mp_group is None: - mp_rank = 0 - else: - mp_rank = dist.get_rank(mp_group) - norm_type = float(norm_type) if norm_type == inf: total_norm = max(g.data.abs().max() for g in gradients) @@ -221,29 +217,21 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2): dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group) # Take max across all GPUs. - if mp_group is not None: + if tp_group is not None: dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX) total_norm = total_norm_cuda[0].item() else: total_norm = 0.0 - # if dist.get_rank() == 0: - # logger.info(f"Total Norm beginning {total_norm}") - - for g, p in zip(gradients, params): - # Pipeline parallelism may replicate parameters. Avoid multi-counting. - tp_param_flag = False - if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()): - tp_param_flag = True - if tp_param_flag or mp_rank == 0: - param_norm = g.data.double().norm(2) - total_norm += param_norm.item()**2 + for g in gradients: + param_norm = g.data.double().norm(2) + total_norm += param_norm.item()**2 # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group) - if mp_group is not None: - dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group) + if tp_group is not None: + dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group) total_norm = total_norm_cuda[0].item()**(1. / norm_type) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 615c870971b1..fd3df052b4a4 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist +from torch.distributed import ProcessGroup from torch.optim import Optimizer from colossalai.amp.naive_amp.mixed_precision_mixin import ( @@ -12,12 +13,9 @@ FP16MixedPrecisionMixin, MixedPrecisionMixin, ) -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.tensor import ColoParameter, ProcessGroup +# from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.utils.cuda import get_current_device from ._utils import ( @@ -56,8 +54,8 @@ def check_local_overflow(self) -> bool: return False -class LowLevelZeroOptimizer(OptimizerWrapper): - """Optimizer used for ZeRO-1 and ZeRO-2. +class TPLowLevelZeroOptimizer(OptimizerWrapper): + """Optimizer used for ZeRO-1 and ZeRO-2 with Tensor Parallelism. """ def __init__( @@ -77,13 +75,15 @@ def __init__( overlap_communication: bool = False, partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload + dp_process_group: ProcessGroup = None, + tp_process_group: ProcessGroup = None, forced_dtype: Optional[torch.dtype] = None): # TODO: # 1. process group api # 2. checkpoint IO - super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) + super(TPLowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]['params'][0].dtype self._logger = get_dist_logger() self._verbose = verbose @@ -96,30 +96,12 @@ def __init__( # grad accumulation self.require_grad_sync = True - colo_pg = self._search_colo_process_group() - if isinstance(colo_pg, ProcessGroup): - self._local_rank = colo_pg.dp_local_rank() - self._world_size = colo_pg.dp_world_size() - self._dp_global_ranks = colo_pg.get_ranks_in_dp() - self._dp_torch_group = colo_pg.dp_process_group() - self._mp_torch_group = None - if colo_pg.tp_world_size() > 1: - self._mp_torch_group = colo_pg.tp_process_group() - elif colo_pg is None: - dp_parallel_mode = ParallelMode.DATA - mp_parallel_mode = ParallelMode.MODEL - - self._dp_parallel_mode = dp_parallel_mode - self._mp_parallel_mode = mp_parallel_mode - self._local_rank = gpc.get_local_rank(dp_parallel_mode) - self._world_size = gpc.get_world_size(dp_parallel_mode) - self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode) - self._dp_torch_group = gpc.get_group(dp_parallel_mode) - self._mp_torch_group = None - if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1: - self._mp_torch_group = gpc.get_group(mp_parallel_mode) - else: - raise NotImplementedError + # if process_group is none, will use the default one + self.dp_pg = dp_process_group + self._local_rank = dist.get_rank(group=self.dp_pg) + self._world_size = dist.get_world_size(group=self.dp_pg) + + self.tp_pg = tp_process_group # working and master params for mixed precision training self._working_param_groups = dict() @@ -145,9 +127,9 @@ def __init__( # ParameterStore will manage the tensor buffers used for zero # it will not manage the tensors used by mixed precision training - self._param_store = ParameterStore(self._dp_torch_group) - self._grad_store = GradientStore(self._dp_torch_group, partition_grad=partition_grad) - self._bucket_store = BucketStore(self._dp_torch_group) + self._param_store = ParameterStore(self.dp_pg) + self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad) + self._bucket_store = BucketStore(self.dp_pg) # iterate over the param group in the optimizer # partition these param groups for data parallel training @@ -212,22 +194,6 @@ def _sanity_checks(self): assert param.dtype == self._dtype, \ f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" - def _search_colo_process_group(self): - colo_flag = False - colo_pg = None - for param_group in self.optim.param_groups: - group_params = param_group['params'] - for param in group_params: - if isinstance(param, ColoParameter): - colo_flag = True - if colo_pg is None: - colo_pg = param.get_process_group() - else: - assert colo_pg == param.get_process_group(), "All parameters should be in a same process group" - elif colo_flag: - raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.") - return colo_pg - def _create_master_param_current_rank(self, param_list): # split each param evenly by world size params_current_rank = [] @@ -291,7 +257,7 @@ def _run_reduction(self): flat_grads = flat_grads.to(self._communication_dtype) if not self._partition_grads: - dist.all_reduce(flat_grads, group=self._dp_torch_group) + dist.all_reduce(flat_grads, group=self.dp_pg) if flat_grads.dtype != grad_dtype: flat_grads = flat_grads.to(grad_dtype) @@ -307,7 +273,7 @@ def _run_reduction(self): else: flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self._dp_torch_group) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) if recieved_grad.dtype != grad_dtype: recieved_grad = recieved_grad.to(grad_dtype) @@ -425,10 +391,7 @@ def step(self, closure=None): # compute norm working_grads = self._grad_store.get_working_grads_by_group_id(group_id) - norm_group = compute_norm(gradients=working_grads, - params=real_working_params[group_id], - dp_group=self._dp_torch_group, - mp_group=self._mp_torch_group) + norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg) norm_groups.append(norm_group) self._grad_store.reset_grads_by_group_id(group_id) @@ -454,7 +417,7 @@ def step(self, closure=None): for idx, splited_param in enumerate(master_working_param): full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)] - dist.all_gather(full_master_param, splited_param.cuda(), group=self._dp_torch_group) + dist.all_gather(full_master_param, splited_param.cuda(), group=self.dp_pg) working_param = real_working_params[group_id][idx] full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param) working_param.data.copy_(full_master_param) @@ -506,3 +469,47 @@ def no_sync(self): yield finally: self.require_grad_sync = old_require_grad_sync + + +class LowLevelZeroOptimizer(TPLowLevelZeroOptimizer): + """Optimizer used for ZeRO-1 and ZeRO-2. + """ + + def __init__( + self, + optimizer: Optimizer, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2., + backoff_factor: float = .5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = False, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: ProcessGroup = None, + forced_dtype: Optional[torch.dtype] = None): + + super(LowLevelZeroOptimizer, self).__init__( + optimizer=optimizer, + initial_scale=initial_scale, # grad scaler config + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + clip_grad_norm=clip_grad_norm, # grad clipping + verbose=verbose, + reduce_bucket_size=reduce_bucket_size, # communication + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + partition_grad=partition_grad, # stage 2 flag + cpu_offload=cpu_offload, # cpu offload + dp_process_group=dp_process_group, + forced_dtype=forced_dtype) diff --git a/tests/test_zero/test_low_level/test_zero_init.py b/tests/test_zero/test_low_level/test_zero_init.py index aeeaff5b5cb9..368ef976ef6e 100644 --- a/tests/test_zero/test_low_level/test_zero_init.py +++ b/tests/test_zero/test_low_level/test_zero_init.py @@ -33,10 +33,9 @@ def exam_zero_init(): assert optimizer1._local_rank == optimizer2._local_rank assert optimizer1._world_size == optimizer2._world_size - assert optimizer1._dp_global_ranks == optimizer2._dp_global_ranks - mp_group1 = optimizer1._mp_torch_group - mp_group2 = optimizer2._mp_torch_group + mp_group1 = optimizer1.tp_pg + mp_group2 = optimizer2.tp_pg assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2) assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2) diff --git a/tests/test_zero/test_low_level/test_zero_tp.py b/tests/test_zero/test_low_level/test_zero_tp.py index f0804f4bb5ba..df299b50178b 100644 --- a/tests/test_zero/test_low_level/test_zero_tp.py +++ b/tests/test_zero/test_low_level/test_zero_tp.py @@ -8,7 +8,7 @@ from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer +from colossalai.zero import ColoInitContext, TPLowLevelZeroOptimizer from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal @@ -53,11 +53,13 @@ def exam_zero_with_tp(overlap_flag, partition_flag): torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group()) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-2) # set to 1e-2 for torch-1.11 hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1e-2) - hybrid_optim = LowLevelZeroOptimizer(hybrid_optim, - initial_scale=2, - clip_grad_norm=1.0, - overlap_communication=overlap_flag, - partition_grad=partition_flag) + hybrid_optim = TPLowLevelZeroOptimizer(hybrid_optim, + initial_scale=2, + clip_grad_norm=1.0, + overlap_communication=overlap_flag, + partition_grad=partition_flag, + dp_process_group=tp_pg.dp_process_group(), + tp_process_group=tp_pg.tp_process_group()) dp_local_rank = tp_pg.dp_local_rank() set_seed(255 + dp_local_rank) From e448a4d59b81328f56a45e03b43f368edcf13f32 Mon Sep 17 00:00:00 2001 From: lclgy Date: Tue, 4 Jul 2023 15:44:11 +0800 Subject: [PATCH 2/3] union tp-zero and normal-zero --- colossalai/zero/__init__.py | 5 +- colossalai/zero/low_level/__init__.py | 4 +- colossalai/zero/low_level/low_level_optim.py | 57 ++----------------- .../test_zero/test_low_level/test_zero_tp.py | 16 +++--- 4 files changed, 18 insertions(+), 64 deletions(-) diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index ce5dabe742b5..3465079e4fbb 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -7,11 +7,10 @@ get_static_torch_model, post_process_colo_init_ctx, ) -from .low_level import LowLevelZeroOptimizer, TPLowLevelZeroOptimizer +from .low_level import LowLevelZeroOptimizer from .wrapper import zero_model_wrapper, zero_optim_wrapper __all__ = [ 'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper', - 'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model', - 'TPLowLevelZeroOptimizer' + 'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model' ] diff --git a/colossalai/zero/low_level/__init__.py b/colossalai/zero/low_level/__init__.py index ace04c32ccb0..ae3c1de3a5bc 100644 --- a/colossalai/zero/low_level/__init__.py +++ b/colossalai/zero/low_level/__init__.py @@ -1,3 +1,3 @@ -from .low_level_optim import LowLevelZeroOptimizer, TPLowLevelZeroOptimizer +from .low_level_optim import LowLevelZeroOptimizer -__all__ = ['LowLevelZeroOptimizer', 'TPLowLevelZeroOptimizer'] +__all__ = ['LowLevelZeroOptimizer'] diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index fd3df052b4a4..0b446fb97123 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -54,8 +54,8 @@ def check_local_overflow(self) -> bool: return False -class TPLowLevelZeroOptimizer(OptimizerWrapper): - """Optimizer used for ZeRO-1 and ZeRO-2 with Tensor Parallelism. +class LowLevelZeroOptimizer(OptimizerWrapper): + """Optimizer used for ZeRO-1 and ZeRO-2. """ def __init__( @@ -75,15 +75,14 @@ def __init__( overlap_communication: bool = False, partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload - dp_process_group: ProcessGroup = None, - tp_process_group: ProcessGroup = None, + dp_process_group: ProcessGroup = None, # the dp pg for comm + tp_process_group: ProcessGroup = None, # if using tp forced_dtype: Optional[torch.dtype] = None): # TODO: - # 1. process group api - # 2. checkpoint IO + # 1. state_dict for checkpoint IO - super(TPLowLevelZeroOptimizer, self).__init__(optim=optimizer) + super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]['params'][0].dtype self._logger = get_dist_logger() self._verbose = verbose @@ -469,47 +468,3 @@ def no_sync(self): yield finally: self.require_grad_sync = old_require_grad_sync - - -class LowLevelZeroOptimizer(TPLowLevelZeroOptimizer): - """Optimizer used for ZeRO-1 and ZeRO-2. - """ - - def __init__( - self, - optimizer: Optimizer, - initial_scale: int = 2**16, # grad scaler config - min_scale: int = 1, - growth_factor: float = 2., - backoff_factor: float = .5, - growth_interval: int = 2000, - hysteresis: int = 2, - max_scale: int = 2**24, - clip_grad_norm: float = 0.0, # grad clipping - verbose: bool = False, - reduce_bucket_size: int = 1024 * 1024, # communication - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = False, - partition_grad: bool = False, # stage 2 flag - cpu_offload: bool = False, # cpu offload - dp_process_group: ProcessGroup = None, - forced_dtype: Optional[torch.dtype] = None): - - super(LowLevelZeroOptimizer, self).__init__( - optimizer=optimizer, - initial_scale=initial_scale, # grad scaler config - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale, - clip_grad_norm=clip_grad_norm, # grad clipping - verbose=verbose, - reduce_bucket_size=reduce_bucket_size, # communication - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - partition_grad=partition_grad, # stage 2 flag - cpu_offload=cpu_offload, # cpu offload - dp_process_group=dp_process_group, - forced_dtype=forced_dtype) diff --git a/tests/test_zero/test_low_level/test_zero_tp.py b/tests/test_zero/test_low_level/test_zero_tp.py index df299b50178b..238de3334c80 100644 --- a/tests/test_zero/test_low_level/test_zero_tp.py +++ b/tests/test_zero/test_low_level/test_zero_tp.py @@ -8,7 +8,7 @@ from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, TPLowLevelZeroOptimizer +from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal @@ -53,13 +53,13 @@ def exam_zero_with_tp(overlap_flag, partition_flag): torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group()) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-2) # set to 1e-2 for torch-1.11 hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1e-2) - hybrid_optim = TPLowLevelZeroOptimizer(hybrid_optim, - initial_scale=2, - clip_grad_norm=1.0, - overlap_communication=overlap_flag, - partition_grad=partition_flag, - dp_process_group=tp_pg.dp_process_group(), - tp_process_group=tp_pg.tp_process_group()) + hybrid_optim = LowLevelZeroOptimizer(hybrid_optim, + initial_scale=2, + clip_grad_norm=1.0, + overlap_communication=overlap_flag, + partition_grad=partition_flag, + dp_process_group=tp_pg.dp_process_group(), + tp_process_group=tp_pg.tp_process_group()) dp_local_rank = tp_pg.dp_local_rank() set_seed(255 + dp_local_rank) From 5ac28573299795fd94825c5b6a2596cfc3034276 Mon Sep 17 00:00:00 2001 From: lclgy Date: Tue, 4 Jul 2023 16:09:11 +0800 Subject: [PATCH 3/3] polish code --- colossalai/zero/low_level/low_level_optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 0b446fb97123..27ac06ec9dc5 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -75,8 +75,8 @@ def __init__( overlap_communication: bool = False, partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload - dp_process_group: ProcessGroup = None, # the dp pg for comm - tp_process_group: ProcessGroup = None, # if using tp + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None): # TODO: