diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 6a857bca378c..646e492cc3dd 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -759,7 +759,6 @@ def _configure_zero_optimizer(self, optimizer): timers = self.timers if self.wall_clock_breakdown() else None if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES: - assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode' optimizer = FP16_DeepSpeedZeroOptimizer_Stage1( optimizer, static_loss_scale=self.loss_scale(), @@ -970,8 +969,8 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): #Communicate only at gradient accumulation boundaries elif self.is_gradient_accumulation_boundary(): - if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES: - assert self.zero_reduce_scatter() + if self.zero_optimization_stage( + ) == ZERO_OPTIMIZATION_OPTIMIZER_STATES and self.zero_reduce_scatter(): self.optimizer.reduce_scatter_gradients( postscale_gradients=self.postscale_gradients(), gradient_predivide_factor=self.gradient_predivide_factor(), diff --git a/deepspeed/runtime/zero/constants.py b/deepspeed/runtime/zero/constants.py index 40b450649850..9cfe313c75b0 100755 --- a/deepspeed/runtime/zero/constants.py +++ b/deepspeed/runtime/zero/constants.py @@ -52,7 +52,7 @@ ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT = True ZERO_OPTIMIZATION_REDUCE_SCATTER = 'reduce_scatter' -ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = True +ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT = False ZERO_OPTIMIZATION_OVERLAP_COMM = 'overlap_comm' ZERO_OPTIMIZATION_OVERLAP_COMM_DEFAULT = False