diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index bdd1de4cbdda..e0ca4f025957 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -37,7 +37,7 @@ def split_half_float_double(tensors): ] buckets = [] for i, dtype in enumerate(dtypes): - bucket = [t for t in tensors if t is not None and t.type() == dtype] + bucket = [t for t in tensors if t.type() == dtype] if bucket: buckets.append(bucket) return buckets @@ -477,6 +477,8 @@ def independent_gradient_partition_epilogue(self): if self.overlap_comm: torch.cuda.synchronize() + # It is safe to clear previously reduced grads of other partitions + self._clear_previous_reduced_grads() if self.cpu_offload is False: for i, _ in enumerate(self.fp16_groups): @@ -638,6 +640,9 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): param.grad.data = new_grad_tensor.data.view_as(param.grad) self.elements_in_ipg_bucket += param.numel() + + assert param.grad is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient" + self.grads_in_ipg_bucket.append(param.grad) self.params_in_ipg_bucket.append((i, param, param_id)) @@ -965,7 +970,7 @@ def reduce_ipg_grads(self): if not self.is_param_in_current_partition[param_id]: if self.overlap_comm and self.contiguous_gradients is False: - # Clear the previous grads during the next reduction + # Clear grads of other partitions during the next reduction # to avoid clearing them before the reduction is complete. if self.previous_reduced_grads is None: self.previous_reduced_grads = [] @@ -1078,16 +1083,18 @@ def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=N return tensor + def _clear_previous_reduced_grads(self): + if self.previous_reduced_grads is not None: + for param in self.previous_reduced_grads: + param.grad = None + self.previous_reduced_grads = None + #if rank is specified do a reduction instead of an allreduce def allreduce_and_copy(self, small_bucket, rank=None, log=None): if self.overlap_comm: torch.cuda.synchronize() - if self.previous_reduced_grads is not None: - # previous_reduced_grads has the previous reduced grads, - # now it is safe to clear. - for param in self.previous_reduced_grads: - param.grad = None - self.previous_reduced_grads = None + # It is safe to clear the previously reduced grads of other partitions + self._clear_previous_reduced_grads() stream = self.reduction_stream else: stream = torch.cuda.current_stream()