diff --git a/DeepSpeedExamples b/DeepSpeedExamples index bdf8e59aede8..127372571189 160000 --- a/DeepSpeedExamples +++ b/DeepSpeedExamples @@ -1 +1 @@ -Subproject commit bdf8e59aede8c8e0577e8d4d557298ca8515268f +Subproject commit 127372571189ac905c8c92f4fe55a3d85c80324e diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index d9e0e399b150..727c0810290e 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -534,7 +534,7 @@ def __init__(self, config: Union[str, dict], mpu=None): object_pairs_hook=dict_raise_error_on_duplicate_keys) else: raise ValueError( - f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {ds_config}" + f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}" ) try: self.global_rank = torch.distributed.get_rank() @@ -765,7 +765,8 @@ def _do_error_check(self): GRADIENT_ACCUMULATION_STEPS) if self.zero_enabled: - assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled" + if self.zero_optimization_stage < ZERO_OPTIMIZATION_GRADIENTS: + assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled" assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION) #if self.zero_config.cpu_offload is True: # assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 87ccf1c99e5a..46f969ab44ec 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -582,7 +582,13 @@ def is_replicated(p): def _configure_distributed_model(self, model): self.module = model if self.fp16_enabled(): + if self.zero_optimization_partition_weights() and any( + [hasattr(param, + 'ds_id') for param in 
self.module.parameters()]): + assert all([param.dtype == torch.half for param in self.module.parameters()]), f"Model must be initialized in fp16 mode for ZeRO Stage 3." self.module.half() + else: + assert all([param.dtype == torch.float for param in self.module.parameters()]), f"fp16 is not enabled but model parameters are not in fp32" if not self.dont_change_device: self.module.to(self.device) @@ -1093,7 +1099,8 @@ def clip_fp32_gradients(self): def _take_model_step(self, lr_kwargs): if self.gradient_clipping() > 0.0: - if not self.fp16_enabled() and not self.amp_enabled(): + if not (self.fp16_enabled() or self.amp_enabled() + or self.zero_optimization()): self.clip_fp32_gradients() elif self.amp_enabled(): # AMP's recommended way of doing clipping
torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) + tensor = torch.ones((1, 1), device=device).new_empty(*args) + return tensor + + reuse_buffers = False temp_contiguous_tensor = None empty_buffers = {} @@ -181,9 +194,11 @@ def new_cuda_tensor(cls, *args): # Inserts _post_init_method at the end of init method # for all sub classes of torch.nn.Module class InsertPostInitMethodToModuleSubClasses(object): - def __init__(self, enabled=True, mem_efficient_linear=True): + def __init__(self, enabled=True, mem_efficient_linear=True, config=None, dtype=None): self.mem_efficient_linear = mem_efficient_linear self.enabled = enabled + self._set_dtype(config, dtype) + assert self.dtype in [torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]" def __enter__(self): if not self.enabled: @@ -219,8 +234,12 @@ def _init_subclass(cls, **kwargs): # Replace .__init__() for future subclasses of torch.nn.Module torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) - torch.Tensor.__new__ = new_cuda_tensor - torch.empty = empty_cuda_tensor + if self.dtype == torch.half: + torch.Tensor.__new__ = new_cuda_tensor_half + torch.empty = empty_cuda_tensor_half + else: + torch.Tensor.__new__ = new_cuda_tensor + torch.empty = empty_cuda_tensor if self.mem_efficient_linear: print_rank_0( @@ -260,6 +279,15 @@ def _disable_class(cls): def _post_init_method(self, module): pass + def _set_dtype(self, ds_config, dtype): + if ds_config is not None and dtype is None: + _ds_config = DeepSpeedConfig(ds_config) + self.dtype = torch.half if _ds_config.fp16_enabled else torch.float + elif dtype is None: + self.dtype = torch.half + else: + self.dtype = dtype + # Replaces all parameters in module with Scattered Parameters class Init(InsertPostInitMethodToModuleSubClasses): @@ -272,7 +300,8 @@ def __init__(self, remote_device=None, pin_memory=False, config=None, - enabled=True): + enabled=True, + dtype=torch.half): """A context to 
enable massive model construction for training with ZeRO-3. Models are automatically partitioned (or, sharded) across the system and converted to half precision. @@ -296,6 +325,8 @@ def __init__(self, for swapping fp16 params to NVMe. enabled (bool, optional): If ``False``, this context has no effect. Defaults to ``True``. + dtype (``torch.dtype``, optional): Can be used to change the data type of the parameters. + Supported options are ``torch.half`` and ``torch.float``. Defaults to ``torch.half`` This context accelerates model initialization and enables models that are too large to allocate in their entirety in CPU memory. It has the @@ -367,7 +398,10 @@ def get_model(): model = deepspeed.zero.Init(module=model) """ - super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear) + super().__init__(enabled=enabled, + mem_efficient_linear=mem_efficient_linear, + config=config, + dtype=dtype) if not torch.distributed.is_initialized(): init_distributed() assert torch.distributed.is_initialized(), "Parameters cannot be scattered without initializing torch.distributed" @@ -632,8 +666,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): f'Before partitioning param {param.ds_id} {param.shape}', force=False) #param.data does not store anything meaningful in partitioned state - param.data = torch.ones(partitioned_param_data_shape).half().to( - param.device) + param.data = torch.ones(1, dtype=self.dtype).to(param.device) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) @@ -714,7 +747,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False) - param.data = torch.ones(partitioned_param_data_shape).half().to(param.device) + param.data = torch.ones(1, dtype=self.dtype).to(param.device) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) diff --git 
a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index 83831643276f..9bf06a585bf1 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -189,6 +189,7 @@ def __init__(self, partition_id = dist.get_rank(group=self.dp_process_group) self.all_reduce_print = False + self.dtype = self.optimizer.param_groups[0]['params'][0].dtype # padding on each partition for alignment purposes self.groups_padding = [] @@ -308,10 +309,12 @@ def __init__(self, self.grad_position = {} self.temp_grad_buffer_for_cpu_offload = torch.zeros( largest_param_numel, - device=self.device).half().pin_memory() + device=self.device, + dtype=self.dtype).pin_memory() self.temp_grad_buffer_for_gpu_offload = torch.zeros( largest_param_numel, - device=torch.cuda.current_device()).half() + device=torch.cuda.current_device(), + dtype=self.dtype) for i, params_group in enumerate(self.fp16_groups): self.get_grad_position(i, @@ -356,7 +359,13 @@ def __init__(self, self.create_reduce_and_remove_grad_hooks() # we may have a way of fusing dynamic scale. 
Do not support for now - if dynamic_loss_scale: + if self.dtype == torch.float or not dynamic_loss_scale: + loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale + + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(scale=loss_scale_value) + self.cur_iter = 0 + else: if dynamic_loss_args is None: self.loss_scaler = DynamicLossScaler() else: @@ -364,11 +373,6 @@ def __init__(self, self.dynamic_loss_scale = True - else: - self.dynamic_loss_scale = False - self.loss_scaler = LossScaler(scale=static_loss_scale) - self.cur_iter = 0 - see_memory_usage("Before initializing optimizer states") self.initialize_optimizer_states() see_memory_usage("After initializing optimizer states") @@ -466,14 +470,14 @@ def independent_gradient_partition_epilogue(self): self.params_in_partition[i], self.first_offset[i], self.partition_size[i], - dtype=torch.half, + dtype=self.dtype, device=torch.cuda.current_device(), return_tensor_list=True) else: avg_new = self.get_flat_partition(self.params_in_partition[i], self.first_offset[i], self.partition_size[i], - dtype=torch.half, + dtype=self.dtype, device=torch.cuda.current_device(), return_tensor_list=True) @@ -946,7 +950,7 @@ def copy_grads_in_partition(self, param): see_memory_usage(f"before copying {total_size} gradients into partition") self.grads_in_partition = torch.empty(int(total_size), - dtype=torch.half, + dtype=self.dtype, device=torch.cuda.current_device()) see_memory_usage(f"after copying {total_size} gradients into partition") @@ -1455,7 +1459,7 @@ def step(self, closure=None): self.start_timers([OPTIMIZER_STEP]) if self.deepspeed_adam_offload: from deepspeed.ops.adam import DeepSpeedCPUAdam - if type(self.optimizer) == DeepSpeedCPUAdam: + if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half: fp16_param_groups = [ fp16_partitions[partition_id] for fp16_partitions in self.parallel_partitioned_fp16_groups @@ -1632,14 +1636,14 @@ def backward(self, loss, retain_graph=False): if
self.contiguous_gradients: self.ipg_buffer = [] buf_0 = torch.empty(int(self.reduce_bucket_size * 4.5), - dtype=torch.half, + dtype=self.dtype, device=torch.cuda.current_device()) self.ipg_buffer.append(buf_0) # Use double buffers to avoid data access conflict when overlap_comm is enabled. if self.overlap_comm: buf_1 = torch.empty(int(self.reduce_bucket_size * 4.5), - dtype=torch.half, + dtype=self.dtype, device=torch.cuda.current_device()) self.ipg_buffer.append(buf_1) self.ipg_index = 0 diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 2b16887ff60d..8b7aee16c4ee 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -640,12 +640,13 @@ def __init__(self, util_ops = UtilsBuilder().load() self.flatten = util_ops.flatten self.unflatten = util_ops.unflatten + self.dtype = self.optimizer.param_groups[0]['params'][0].dtype if not all(is_zero_param(p) for p in module.parameters()): group = None if mpu: group = mpu.get_data_parallel_group() - Init(module=module, data_parallel_group=group) + Init(module=module, data_parallel_group=group, dtype=self.dtype) for m in module.modules(): _init_external_params(m) @@ -791,7 +792,6 @@ def __init__(self, self.sub_group_size = sub_group_size self.sub_group_to_group_id = {} - see_memory_usage("Before creating fp16 partitions", force=True) self._create_fp16_partitions_with_defragmentation() num_fp16_subgroups = len(self.fp16_partitioned_groups_flat) @@ -874,10 +874,11 @@ def __init__(self, self.local_overflow = False self.temp_grad_buffer_for_gpu_offload = torch.zeros( largest_partitioned_param_numel, - device=torch.cuda.current_device()).half() - self.temp_grad_gpu_buffer = torch.zeros( - largest_partitioned_param_numel, - device=torch.cuda.current_device()).half() + device=torch.cuda.current_device(), + dtype=self.dtype) + self.temp_grad_gpu_buffer = torch.zeros(largest_partitioned_param_numel, + device=torch.cuda.current_device(), + dtype=self.dtype) 
see_memory_usage(f"After CPU Offload initialization", force=False) # stores if a partition has been reduced in this step @@ -895,7 +896,13 @@ def __init__(self, #exit(0) # we may have a way of fusing dynamic scale. Do not support for now - if dynamic_loss_scale: + if self.dtype == torch.float or not dynamic_loss_scale: + loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale + + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(scale=loss_scale_value) + self.cur_iter = 0 + else: if dynamic_loss_args is None: self.loss_scaler = DynamicLossScaler() else: @@ -903,11 +910,6 @@ def __init__(self, self.dynamic_loss_scale = True - else: - self.dynamic_loss_scale = False - self.loss_scaler = LossScaler(scale=static_loss_scale) - self.cur_iter = 0 - self.debug_fp16_grads = [{} for _ in self.fp16_groups] if dist.get_rank(group=self.dp_process_group) == 0: @@ -1059,7 +1061,7 @@ def _create_param_groups_fp16_flat_cpu_memory(self): force=False) self.param_groups_fp16_flat_cpu_memory.append( torch.empty(int(flat_buffer_size), - dtype=torch.half, + dtype=self.dtype, pin_memory=True)) else: print_rank_0( @@ -1068,7 +1070,7 @@ self.param_groups_fp16_flat_cpu_memory.append( torch.empty(1, - dtype=torch.half, + dtype=self.dtype)) def _create_fp16_partitions_with_defragmentation(self): dist.barrier() @@ -1170,7 +1172,7 @@ def _create_fp16_partitions_with_defragmentation(self): -1] is None and self.param_group_fp16_flat_reuse_buffer is None: self.param_group_fp16_flat_reuse_buffer = torch.empty( max(self.fp16_partitioned_groups_flat_numel), - dtype=torch.half, + dtype=self.dtype, device='cpu', pin_memory=True) @@ -2076,12 +2078,12 @@ def partition_previous_reduced_grads(self): if self.offload_param_pin_memory: self.grads_in_partition.append( torch.zeros(int(total_size), - dtype=torch.half, + dtype=self.dtype, device=self.device).pin_memory()) else: self.grads_in_partition.append( 
torch.zeros(int(total_size), - dtype=torch.half, + dtype=self.dtype, device=self.device)) see_memory_usage( f"group {i} after creating {total_size} reduced gradients into partition", @@ -2929,14 +2931,14 @@ def backward(self, loss, retain_graph=False): if self.contiguous_gradients: self.ipg_buffer = [] buf_0 = torch.empty(self.reduce_bucket_size, - dtype=torch.half, + dtype=self.dtype, device=torch.cuda.current_device()) self.ipg_buffer.append(buf_0) # Use double buffers to avoid data access conflict when overlap_comm is enabled. if self.overlap_comm: buf_1 = torch.empty(self.reduce_bucket_size, - dtype=torch.half, + dtype=self.dtype, device=torch.cuda.current_device()) self.ipg_buffer.append(buf_1) self.ipg_index = 0