diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 9cc58fdbac01..ee97b6278d9e 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -732,10 +732,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): # move parameters to flattened buffer if not self.offload_param: # partitioned params remain in GPU during training # move parameter partitions into a single contiguous flat buffer - parameter_partitions: List[Tensor] = [] - for sub_group in self.fp16_groups: - for param in sub_group: - parameter_partitions.append(param.ds_tensor) + parameter_partitions = self._get_parameter_partitions() # We need to keep the reference to this buffer to make sure you can free it in `offload_states` self.lp_param_buffer = __class__.defragment(parameter_partitions) @@ -786,6 +783,9 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space(largest_partition_numel) + def _get_parameter_partitions(self) -> List[Tensor]: + return [param.ds_tensor for sub_group in self.fp16_groups for param in sub_group] + def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id): offset = 0 elements_in_sub_group = sum([t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]]) @@ -2954,8 +2954,8 @@ def reload_states(self, non_blocking: bool = False): self.lp_param_buffer.data = cpu_buffer.data.to(device, non_blocking=non_blocking) self._set_fp16_partitioned_groups_flat() - for tensor, offset, tensor_numel in get_mapping_to_flat_buffer( - [p.ds_tensor for p in self.module.parameters()]): + parameter_partitions = self._get_parameter_partitions() + for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(parameter_partitions): tensor.data = self.lp_param_buffer.narrow(0, offset, tensor_numel) self.offloaded_states.remove(OffloadStateTypeEnum.lp_params) diff --git a/tests/unit/runtime/zero/test_offload_states.py b/tests/unit/runtime/zero/test_offload_states.py index 9105a54661fa..44bff480e27b 100644 --- a/tests/unit/runtime/zero/test_offload_states.py +++ b/tests/unit/runtime/zero/test_offload_states.py @@ -33,11 +33,11 @@ def compare_device(state) -> bool: assert compare_device(state), f"State {state} is not on device {device}" -def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking): +def run_model(model, param_groups, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking): # Currently we only support OffloadDeviceEnum.cpu offload_device = OffloadDeviceEnum.cpu - model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=param_groups, config=config_dict) data_loader = random_dataloader(model=model, total_samples=10, hidden_dim=hidden_dim, @@ -124,5 +124,12 @@ def test_offload_states(self, included_state, pin_memory, non_blocking): with deepspeed.zero.Init(config_dict_or_path=config_dict): model = SimpleModel(hidden_dim, nlayers=4) + param_groups = [{ + "params": [p for n, p in model.named_parameters() if not 'bias' in n], + "weight_decay": 0.1 + }, { + "params": [p for n, p in model.named_parameters() if 'bias' in n], + "weight_decay": 0.0 + }] include = None if included_state is None else [included_state] - run_model(model, config_dict, hidden_dim, torch.bfloat16, include, pin_memory, non_blocking) + run_model(model, param_groups, config_dict, hidden_dim, torch.bfloat16, include, pin_memory, non_blocking)