From 758d91111d02a99c89340cd78ad2aebf15dc1dd8 Mon Sep 17 00:00:00 2001
From: Zhongkai Zhao
Date: Wed, 27 Sep 2023 09:37:58 +0800
Subject: [PATCH 01/13] [feature] support no master weights for low level zero
 plugin

---
 .../booster/plugin/low_level_zero_plugin.py  |  6 ++++--
 colossalai/zero/low_level/low_level_optim.py | 15 ++++++++++++---
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 457c720f6418..a395249f23cf 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -261,6 +261,7 @@ def __init__(
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
         cpu_offload: bool = False,
+        master_weights: bool = True,
         verbose: bool = False,
     ) -> None:
         super().__init__()
@@ -271,18 +272,19 @@ def __init__(
         self.precision = precision
         self.zero_optim_kwargs = dict(
             initial_scale=initial_scale,
+            min_scale=min_scale,
             growth_factor=growth_factor,
             backoff_factor=backoff_factor,
             growth_interval=growth_interval,
             hysteresis=hysteresis,
-            min_scale=min_scale,
             max_scale=max_scale,
             clip_grad_norm=max_norm,
             reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
             communication_dtype=communication_dtype,
             overlap_communication=overlap_communication,
-            cpu_offload=cpu_offload,
             partition_grad=(stage == 2),
+            cpu_offload=cpu_offload,
+            master_weights=master_weights,
         )
         self.verbose = verbose

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 72df93ace302..1f0dccb72203 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -79,6 +79,7 @@ def __init__(
         overlap_communication: bool = False,
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
+        master_weights: bool = True,  # master weights
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         tp_process_group: Optional[ProcessGroup] = None,  # if using tp
         forced_dtype: Optional[torch.dtype] = None,
@@ -115,6 +116,9 @@ def __init__(
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm

+        # master weights copy
+        self._master_weights = master_weights
+
         if forced_dtype:
             for group in self.optim.param_groups:
                 group_params = group["params"]
@@ -213,7 +217,11 @@ def _create_master_param_current_rank(self, param_list):
             else:
                 padding_param = param.data.view(-1)
             splited_params = padding_param.split(padding_param.numel() // self._world_size)
-            splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+            # use fp32 when master_weights is True
+            if self._master_weights is True:
+                splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+            else:
+                splited_param_current_rank = splited_params[self._local_rank].detach().to(self._dtype).to(device)
             params_current_rank.append(splited_param_current_rank)
             self._param_store.link_master_and_working_param(splited_param_current_rank, param)
@@ -424,6 +432,7 @@ def step(self, closure=None):
                 # it is not 'really' working, e.g. the droped layer
                 # else the splited grad should be attached to the splited param
                 grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
+
                 if len(grads) > 0:
                     real_working_params[group_id].append(working_param)
                     grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
                     splited_param.grad = grad
                     grad_partition_groups.append(grad)
                     real_master_params[group_id].append(splited_param)
@@ -458,13 +467,13 @@ def step(self, closure=None):
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
-                working_param = real_working_params[group_id][idx]
+                working_param = real_working_params[group_id][idx]
                 all_splited_param = [
                     torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
                 ]
                 dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
-
+
         self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

     #############################
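For context, this is roughly how the new flag is driven from the booster API. A minimal sketch, assuming the standard colossalai Booster workflow; the toy model, optimizer, and hyperparameters are placeholders:

    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    colossalai.launch_from_torch(config={})
    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # master_weights=False keeps the sharded optimizer parameters in the
    # working precision (e.g. fp16) instead of an extra fp32 master copy,
    # trading update accuracy headroom for memory.
    plugin = LowLevelZeroPlugin(stage=2, precision="fp16", master_weights=False)
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)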
From 7d0ecfb4d343fb60f507fdbdc0a181e4bb9ac1d3 Mon Sep 17 00:00:00 2001
From: Zhongkai Zhao
Date: Wed, 27 Sep 2023 15:21:55 +0800
Subject: [PATCH 02/13] [feature] support no master weights for low level zero
 plugin, remove data copy when no master weights

---
 colossalai/zero/low_level/low_level_optim.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 1f0dccb72203..ed9611234758 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -221,7 +221,7 @@ def _create_master_param_current_rank(self, param_list):
             # use fp32 when master_weights is True
             if self._master_weights is True:
                 splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
             else:
-                splited_param_current_rank = splited_params[self._local_rank].detach().to(self._dtype).to(device)
+                splited_param_current_rank = splited_params[self._local_rank].to(self._dtype).to(device)
             params_current_rank.append(splited_param_current_rank)
             self._param_store.link_master_and_working_param(splited_param_current_rank, param)

From c4cde46944446e84d0da8fd1f20ff38783bcdb05 Mon Sep 17 00:00:00 2001
From: Zhongkai Zhao
Date: Wed, 27 Sep 2023 15:28:50 +0800
Subject: [PATCH 03/13] remove data copy and typecasting when no master
 weights

---
 colossalai/zero/low_level/low_level_optim.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index ed9611234758..c45797ecc541 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -221,7 +221,7 @@ def _create_master_param_current_rank(self, param_list):
             # use fp32 when master_weights is True
             if self._master_weights is True:
                 splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
             else:
-                splited_param_current_rank = splited_params[self._local_rank].to(self._dtype).to(device)
+                splited_param_current_rank = splited_params[self._local_rank].to(device)
             params_current_rank.append(splited_param_current_rank)
             self._param_store.link_master_and_working_param(splited_param_current_rank, param)

From 3c414ec477a22fbd19c3af3b40e01b7c988e4132 Mon Sep 17 00:00:00 2001
From: Zhongkai Zhao
Date: Wed, 27 Sep 2023 17:56:21 +0800
Subject: [PATCH 04/13] not load weights to cpu when using no master weights

---
 colossalai/zero/low_level/low_level_optim.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index c45797ecc541..c628597d244b 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -216,15 +216,18 @@ def _create_master_param_current_rank(self, param_list):
             else:
                 padding_param = param.data.view(-1)
             splited_params = padding_param.split(padding_param.numel() // self._world_size)
-
+
             # use fp32 when master_weights is True
             if self._master_weights is True:
-                splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+                splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
             else:
-                splited_param_current_rank = splited_params[self._local_rank].to(device)
+                splited_param_current_rank = splited_params[self._local_rank]
+
             params_current_rank.append(splited_param_current_rank)
+            # should also link the splited_param to param when master_weights is False
+            # or the grad cannot be found in step() method
             self._param_store.link_master_and_working_param(splited_param_current_rank, param)
-
+
         return params_current_rank
@@ -432,7 +435,6 @@ def step(self, closure=None):
                 # it is not 'really' working, e.g. the droped layer
                 # else the splited grad should be attached to the splited param
                 grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
-
                 if len(grads) > 0:
                     real_working_params[group_id].append(working_param)
                     grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
@@ -467,13 +469,12 @@ def step(self, closure=None):
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
-                working_param = real_working_params[group_id][idx]
+                working_param = real_working_params[group_id][idx]
                 all_splited_param = [
                     torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
                 ]
                 dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
-
         self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
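The pad-and-split arithmetic these hunks keep touching is easiest to read in isolation. A standalone sketch of the same idea in plain PyTorch (the helper name and the numbers are invented for illustration):

    import torch

    def shard_flat_param(param: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
        flat = param.data.view(-1)
        # pad so the flattened parameter divides evenly across ranks
        padding = (world_size - flat.numel() % world_size) % world_size
        if padding > 0:
            flat = torch.nn.functional.pad(flat, [0, padding])
        # each rank keeps only its own slice; calling .detach().float() on it
        # would produce the fp32 master copy, as in the master_weights=True branch
        return flat.split(flat.numel() // world_size)[rank]

    p = torch.randn(10, dtype=torch.float16)
    print(shard_flat_param(p, world_size=4, rank=1).shape)  # torch.Size([3]): 10 padded to 12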
From ecaadd79aa31134490c4484dd0e179246580b940 Mon Sep 17 00:00:00 2001
From: Zhongkai Zhao
Date: Wed, 11 Oct 2023 12:17:39 +0800
Subject: [PATCH 05/13] fix grad: use fp16 grad when no master weights

---
 colossalai/zero/low_level/low_level_optim.py | 26 ++++++++++++--------------
 1 file changed, 12 insertions(+), 14 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index c628597d244b..e088a7b0ece4 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -31,7 +31,6 @@
 )
 from .bookkeeping import BucketStore, GradientStore, ParameterStore

-
 class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
     def __init__(
         self,
@@ -148,7 +147,6 @@ def __init__(
             self._working_param_groups[group_id] = group_params

             master_param_current_rank = self._create_master_param_current_rank(group_params)
-
             self._master_param_groups_of_current_rank[group_id] = master_param_current_rank

             # need to replace the params in the `params` field in the optimizer
@@ -216,16 +214,13 @@ def _create_master_param_current_rank(self, param_list):
             else:
                 padding_param = param.data.view(-1)
             splited_params = padding_param.split(padding_param.numel() // self._world_size)
-
+
             # use fp32 when master_weights is True
             if self._master_weights is True:
-                splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+                splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
             else:
-                splited_param_current_rank = splited_params[self._local_rank]
-
+                splited_param_current_rank = splited_params[self._local_rank]
             params_current_rank.append(splited_param_current_rank)
-            # should also link the splited_param to param when master_weights is False
-            # or the grad cannot be found in step() method
             self._param_store.link_master_and_working_param(splited_param_current_rank, param)

         return params_current_rank
@@ -422,9 +417,7 @@ def step(self, closure=None):
        # and should not be updated
         real_working_params = dict()
         real_master_params = dict()
-
         grad_index = 0 if self._partition_grads else self._local_rank
-
         for group_id in range(self.num_param_groups):
             master_params = self._master_param_groups_of_current_rank[group_id]
             real_working_params[group_id] = []
@@ -437,9 +430,14 @@ def step(self, closure=None):
                 grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
                 if len(grads) > 0:
                     real_working_params[group_id].append(working_param)
-                    grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
-                    splited_param.grad = grad
-                    grad_partition_groups.append(grad)
+                    # no need to copy fp32 grad if master_weights is False
+                    if self._master_weights:
+                        grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
+                        splited_param.grad = grad
+                        grad_partition_groups.append(grad)
+                    else:
+                        splited_param.grad = grads[grad_index]
+                        grad_partition_groups.append(grads[grad_index])
                     real_master_params[group_id].append(splited_param)

         # compute norm
@@ -458,7 +456,7 @@ def step(self, closure=None):
         # update the parameters
         self.optim.step()
-
+
         # release the grad
         grad_partition_groups = []
         for group_id in range(self.num_param_groups):
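The fp32-versus-fp16 grad branch above follows the usual mixed-precision recipe. A toy sketch of one update step with a master copy, illustrating the principle rather than the ColossalAI code path (names and the 0.1 learning rate are invented):

    import torch

    working = torch.randn(4, dtype=torch.float16)    # fp16 working param
    master = working.detach().float()                # fp32 master copy
    fp16_grad = torch.randn(4, dtype=torch.float16)  # grad as produced by backward

    master_grad = fp16_grad.to(master.dtype)      # the cast/copy skipped when master_weights=False
    master -= 0.1 * master_grad                   # SGD-style update done in fp32
    working.data.copy_(master.to(working.dtype))  # write back to the fp16 working param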
From 14b4be0579aabbda8730360c9e503306e8880048 Mon Sep 17 00:00:00 2001
From: Zhongkai Zhao
Date: Wed, 11 Oct 2023 12:29:12 +0800
Subject: [PATCH 06/13] fix code complexity

---
 colossalai/zero/low_level/low_level_optim.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index e088a7b0ece4..5c5821bf7d2d 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -431,13 +431,9 @@ def step(self, closure=None):
                 if len(grads) > 0:
                     real_working_params[group_id].append(working_param)
                     # no need to copy fp32 grad if master_weights is False
-                    if self._master_weights:
-                        grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
-                        splited_param.grad = grad
-                        grad_partition_groups.append(grad)
-                    else:
-                        splited_param.grad = grads[grad_index]
-                        grad_partition_groups.append(grads[grad_index])
+                    grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device) if self._master_weights else grads[grad_index]
+                    splited_param.grad = grad
+                    grad_partition_groups.append(grad)
                     real_master_params[group_id].append(splited_param)

         # compute norm

From ec341f20e94eb3851c311968ffd139a0224b75f2 Mon Sep 17 00:00:00 2001
From: Zhongkai Zhao
Date: Wed, 11 Oct 2023 12:39:50 +0800
Subject: [PATCH 07/13] retry fix code complexity

---
 colossalai/zero/low_level/low_level_optim.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 5c5821bf7d2d..1059e7c0d1a1 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -395,7 +395,7 @@ def zero_grad(self, set_to_none=True):
     ####################
     # Update Parameter #
     ####################
-
+
     def step(self, closure=None):
         assert closure is None, "closure is not supported by step()"
         if not self.require_grad_sync:
@@ -459,15 +459,15 @@ def step(self, closure=None):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])

         # update working partition updated by the current rank
-        dtype = real_working_params[0][0].dtype
+        # dtype = real_working_params[0][0].dtype
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
                 all_splited_param = [
-                    torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
+                    torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size)
                 ]
-                dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
+                dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))

         self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
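The gather-and-copy at the end of step() reassembles each full working param from the per-rank shards. A minimal sketch of that sequence, with torch.cat standing in for the flatten helper; it assumes an already-initialized process group and the function name is invented:

    import torch
    import torch.distributed as dist

    def rebuild_working_param(shard, working_param, world_size, group=None):
        # collect every rank's shard of the padded flat parameter
        shards = [torch.empty_like(shard) for _ in range(world_size)]
        dist.all_gather(shards, shard, group=group)
        # drop the padding and restore the original shape
        flat = torch.cat(shards)
        working_param.data.copy_(flat[: working_param.numel()].reshape_as(working_param))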
From 53a56096e4187971d6b71c9856bf33bff2223afe Mon Sep 17 00:00:00 2001
From: Zhongkai Zhao
Date: Thu, 12 Oct 2023 10:16:53 +0800
Subject: [PATCH 08/13] do not update working param

---
 colossalai/zero/low_level/low_level_optim.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 1059e7c0d1a1..27c135d93a7d 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -31,6 +31,7 @@
 )
 from .bookkeeping import BucketStore, GradientStore, ParameterStore

+
 class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
     def __init__(
         self,
@@ -211,6 +212,7 @@ def _create_master_param_current_rank(self, param_list):
         with torch.no_grad():
             if padding_size > 0:
                 padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
+                param.data = padding_param[: param.numel()].view(param.shape)
             else:
                 padding_param = param.data.view(-1)
             splited_params = padding_param.split(padding_param.numel() // self._world_size)
@@ -222,7 +224,7 @@ def _create_master_param_current_rank(self, param_list):
                 splited_param_current_rank = splited_params[self._local_rank]
             params_current_rank.append(splited_param_current_rank)
             self._param_store.link_master_and_working_param(splited_param_current_rank, param)
-
+
         return params_current_rank

     ###########################
@@ -395,7 +397,7 @@ def zero_grad(self, set_to_none=True):
     ####################
     # Update Parameter #
     ####################
-
+
     def step(self, closure=None):
         assert closure is None, "closure is not supported by step()"
         if not self.require_grad_sync:
@@ -431,7 +433,11 @@ def step(self, closure=None):
                 if len(grads) > 0:
                     real_working_params[group_id].append(working_param)
                     # no need to copy fp32 grad if master_weights is False
-                    grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device) if self._master_weights else grads[grad_index]
+                    grad = (
+                        grads[grad_index].to(splited_param.dtype).to(splited_param.device)
+                        if self._master_weights
+                        else grads[grad_index]
+                    )
                     splited_param.grad = grad
                     grad_partition_groups.append(grad)
                     real_master_params[group_id].append(splited_param)
@@ -452,7 +458,7 @@ def step(self, closure=None):
         # update the parameters
         self.optim.step()
-
+
         # release the grad
         grad_partition_groups = []
         for group_id in range(self.num_param_groups):
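Why re-pointing param.data matters here: F.pad allocates fresh storage, so the shards view the padded buffer rather than the original parameter. Re-aliasing the working param into that buffer makes in-place updates to a shard visible through the param, which is what lets the no-master-weights path skip the explicit copy-back. A toy demonstration of the aliasing (numbers invented):

    import torch

    param = torch.randn(10)
    padded = torch.nn.functional.pad(param.data.view(-1), [0, 2])  # new storage
    param.data = padded[: param.numel()].view(param.shape)         # re-alias into it

    shard = padded.split(3)[0]   # one rank's slice of the padded buffer
    shard.zero_()                # in-place update of the shard...
    print(param.view(-1)[:3])    # ...shows through the param: tensor([0., 0., 0.])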
From 90b2426e41082bfba82c0fefefb31c54e5b4a5e0 Mon Sep 17 00:00:00 2001
From: Zhongkai Zhao
Date: Thu, 12 Oct 2023 13:27:50 +0800
Subject: [PATCH 09/13] only do not update working param when no master
 weights

---
 colossalai/zero/low_level/low_level_optim.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 27c135d93a7d..befb600b9e57 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -212,7 +212,8 @@ def _create_master_param_current_rank(self, param_list):
         with torch.no_grad():
             if padding_size > 0:
                 padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
-                param.data = padding_param[: param.numel()].view(param.shape)
+                if self._master_weights is True:
+                    param.data = padding_param[: param.numel()].view(param.shape)
             else:
                 padding_param = param.data.view(-1)
             splited_params = padding_param.split(padding_param.numel() // self._world_size)

From 023c13e0a1a6ccd54155288f5ea46718eda4ab6d Mon Sep 17 00:00:00 2001
From: Zhongkai Zhao
Date: Thu, 12 Oct 2023 13:53:42 +0800
Subject: [PATCH 10/13] fix: only do not update working param when no master
 weights

---
 colossalai/zero/low_level/low_level_optim.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index befb600b9e57..c1e4e1bbd3e7 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -212,7 +212,7 @@ def _create_master_param_current_rank(self, param_list):
         with torch.no_grad():
             if padding_size > 0:
                 padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
-                if self._master_weights is True:
+                if self._master_weights == False:
                     param.data = padding_param[: param.numel()].view(param.shape)
             else:
                 padding_param = param.data.view(-1)

From 3c82abd55d8dae90c67df59e2d35d486837cc5a7 Mon Sep 17 00:00:00 2001
From: Zhongkai Zhao
Date: Thu, 12 Oct 2023 17:16:23 +0800
Subject: [PATCH 11/13] fix: passing params in dict format in hybrid plugin

---
 .../booster/plugin/hybrid_parallel_plugin.py | 36 +++++++++----------
 colossalai/zero/low_level/low_level_optim.py |  2 +-
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 46930887bf9c..8101ce8b84aa 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -239,24 +239,24 @@ def __init__(
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
         super().__init__(
-            optimizer,
-            initial_scale,
-            min_scale,
-            growth_factor,
-            backoff_factor,
-            growth_interval,
-            hysteresis,
-            max_scale,
-            clip_grad_norm,
-            verbose,
-            reduce_bucket_size,
-            communication_dtype,
-            overlap_communication,
-            partition_grad,
-            cpu_offload,
-            dp_process_group,
-            tp_process_group,
-            forced_dtype,
+            optimizer=optimizer,
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+            clip_grad_norm=clip_grad_norm,
+            verbose=verbose,
+            reduce_bucket_size=reduce_bucket_size,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            partition_grad=partition_grad,
+            cpu_offload=cpu_offload,
+            dp_process_group=dp_process_group,
+            tp_process_group=tp_process_group,
+            forced_dtype=forced_dtype,
         )

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index c1e4e1bbd3e7..2cf2bd385e35 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -79,10 +79,10 @@ def __init__(
         overlap_communication: bool = False,
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
-        master_weights: bool = True,  # master weights
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         tp_process_group: Optional[ProcessGroup] = None,  # if using tp
         forced_dtype: Optional[torch.dtype] = None,
+        master_weights: bool = True,  # master weights
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
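The switch to keyword arguments guards against exactly the bug this series could have introduced: inserting master_weights into the middle of LowLevelZeroOptimizer's signature would silently shift every later positional argument. A small illustration with a hypothetical function:

    def init(optimizer, cpu_offload=False, dp_process_group=None, tp_process_group=None):
        return cpu_offload, dp_process_group, tp_process_group

    # positional call written against the old signature: if a new parameter
    # is inserted after cpu_offload, "dp_pg" silently lands in the wrong slot
    print(init("opt", False, "dp_pg", "tp_pg"))
    # keyword call keeps its meaning regardless of signature reordering
    print(init("opt", cpu_offload=False, dp_process_group="dp_pg", tp_process_group="tp_pg"))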
partition_grad, - cpu_offload, - dp_process_group, - tp_process_group, - forced_dtype, + optimizer=optimizer, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + clip_grad_norm=clip_grad_norm, + verbose=verbose, + reduce_bucket_size=reduce_bucket_size, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + partition_grad=partition_grad, + cpu_offload=cpu_offload, + dp_process_group=dp_process_group, + tp_process_group=tp_process_group, + forced_dtype=forced_dtype, ) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index c1e4e1bbd3e7..2cf2bd385e35 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -79,10 +79,10 @@ def __init__( overlap_communication: bool = False, partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload - master_weights: bool = True, # master weights dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None, + master_weights: bool = True, # master weights ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]["params"][0].dtype From 0172e731b69183fd1299cd9afc2cfc7f079cc443 Mon Sep 17 00:00:00 2001 From: Zhongkai Zhao Date: Fri, 13 Oct 2023 10:41:20 +0800 Subject: [PATCH 12/13] fix: remove extra params (tp_process_group) in hybrid_parallel_plugin --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index b71dd6abd64f..72c3ec46ae75 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -480,7 +480,6 @@ def __init__( partition_grad=partition_grad, cpu_offload=cpu_offload, dp_process_group=dp_process_group, - tp_process_group=tp_process_group, forced_dtype=forced_dtype, ) From 7a26ae128cba83927197308f4de403aea2546556 Mon Sep 17 00:00:00 2001 From: Zhongkai Zhao Date: Fri, 13 Oct 2023 10:48:04 +0800 Subject: [PATCH 13/13] add a comment --- colossalai/zero/low_level/low_level_optim.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 81ff37af8aff..e6974a6760ce 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -203,6 +203,7 @@ def _create_master_param_current_rank(self, param_list): with torch.no_grad(): if padding_size > 0: padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + # reset working params' ptr when no master weights if self._master_weights == False: param.data = padding_param[: param.numel()].view(param.shape) else: