Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions deepspeed/runtime/zero/stage2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def split_half_float_double(tensors):
]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t is not None and t.type() == dtype]
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append(bucket)
return buckets
Expand Down Expand Up @@ -477,6 +477,8 @@ def independent_gradient_partition_epilogue(self):

if self.overlap_comm:
torch.cuda.synchronize()
# It is safe to clear previously reduced grads of other partitions
self._clear_previous_reduced_grads()

if self.cpu_offload is False:
for i, _ in enumerate(self.fp16_groups):
Expand Down Expand Up @@ -638,6 +640,9 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
param.grad.data = new_grad_tensor.data.view_as(param.grad)

self.elements_in_ipg_bucket += param.numel()

assert param.grad is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

self.grads_in_ipg_bucket.append(param.grad)
self.params_in_ipg_bucket.append((i, param, param_id))

Expand Down Expand Up @@ -965,7 +970,7 @@ def reduce_ipg_grads(self):

if not self.is_param_in_current_partition[param_id]:
if self.overlap_comm and self.contiguous_gradients is False:
# Clear the previous grads during the next reduction
# Clear grads of other partitions during the next reduction
# to avoid clearing them before the reduction is complete.
if self.previous_reduced_grads is None:
self.previous_reduced_grads = []
Expand Down Expand Up @@ -1078,16 +1083,18 @@ def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=N

return tensor

def _clear_previous_reduced_grads(self):
if self.previous_reduced_grads is not None:
for param in self.previous_reduced_grads:
param.grad = None
self.previous_reduced_grads = None

#if rank is specified do a reduction instead of an allreduce
def allreduce_and_copy(self, small_bucket, rank=None, log=None):
if self.overlap_comm:
torch.cuda.synchronize()
if self.previous_reduced_grads is not None:
# previous_reduced_grads has the previous reduced grads,
# now it is safe to clear.
for param in self.previous_reduced_grads:
param.grad = None
self.previous_reduced_grads = None
# It is safe to clear the previously reduced grads of other partitions
self._clear_previous_reduced_grads()
stream = self.reduction_stream
else:
stream = torch.cuda.current_stream()
Expand Down