diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 1462225ac2bd..8a9d5abf173f 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -722,6 +722,7 @@ def _configure_zero_optimizer(self, optimizer):
         zero_stage = self.zero_optimization_stage()
         log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage), ranks=[0])
         assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true"
+        timers = self.timers if self.wall_clock_breakdown() else None

         if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
             assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
@@ -740,7 +741,7 @@ def _configure_zero_optimizer(self, optimizer):
         elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
             optimizer = FP16_DeepSpeedZeroOptimizer(
                 optimizer,
-                timers=self.timers,
+                timers=timers,
                 static_loss_scale=self.loss_scale(),
                 dynamic_loss_scale=self.dynamic_loss_scale(),
                 dynamic_loss_args=self.dynamic_loss_scale_args(),
@@ -762,7 +763,7 @@ def _configure_zero_optimizer(self, optimizer):
             optimizer = FP16_DeepSpeedZeroOptimizer_Stage3(
                 self.module,
                 optimizer,
-                timers=self.timers,
+                timers=timers,
                 static_loss_scale=self.loss_scale(),
                 dynamic_loss_scale=self.dynamic_loss_scale(),
                 dynamic_loss_args=self.dynamic_loss_scale_args(),
diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index 6f3fb1cd6509..bdd1de4cbdda 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -1326,6 +1326,26 @@ def reset_cpu_buffers(self):
         self.norm_for_param_grads = {}
         self.local_overflow = False

+    def log_timers(self, timer_names):
+        if self.timers is None:
+            return
+
+        self.timers.log(names=list(timer_names))
+
+    def start_timers(self, timer_names):
+        if self.timers is None:
+            return
+
+        for name in timer_names:
+            self.timers(name).start()
+
+    def stop_timers(self, timer_names):
+        if self.timers is None:
+            return
+
+        for name in timer_names:
+            self.timers(name).stop()
+
     def step(self, closure=None):
         """
         Not supporting closure.
@@ -1340,7 +1360,10 @@ def step(self, closure=None):
         # First compute norm for all group so we know if there is overflow
         self.check_overflow()

-        timers = self.timers
+        OPTIMIZER_ALLGATHER = 'optimizer_allgather'
+        OPTIMIZER_GRADIENTS = 'optimizer_gradients'
+        OPTIMIZER_STEP = 'optimizer_step'
+        timer_names = [OPTIMIZER_ALLGATHER, OPTIMIZER_GRADIENTS, OPTIMIZER_STEP]

         prev_scale = self.loss_scale
         self._update_scale(self.overflow)
@@ -1359,15 +1382,11 @@ def step(self, closure=None):
                 "reducing to {}".format(dist.get_rank(),
                                         prev_scale,
                                         self.loss_scale))
-            timers('optimizer_gradients').start()
-            timers('optimizer_gradients').stop()
-            timers('optimizer_step').start()
-            timers('optimizer_step').stop()
-            timers('optimizer_allgather').start()
-            timers('optimizer_allgather').stop()
+            self.start_timers(timer_names)
+            self.stop_timers(timer_names)
             return

-        timers('optimizer_gradients').start()
+        self.start_timers([OPTIMIZER_GRADIENTS])
         norm_groups = []
         single_partition_grad_groups = []
         skip = False
@@ -1409,10 +1428,9 @@ def step(self, closure=None):
             single_partition_grad_groups.append(single_grad_partition)

         self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
-        timers('optimizer_gradients').stop()
+        self.stop_timers([OPTIMIZER_GRADIENTS])

-        #torch.set_num_threads(12)
-        timers('optimizer_step').start()
+        self.start_timers([OPTIMIZER_STEP])
         if self.deepspeed_adam_offload:
             from deepspeed.ops.adam import DeepSpeedCPUAdam
             if type(self.optimizer) == DeepSpeedCPUAdam:
@@ -1436,12 +1454,12 @@ def step(self, closure=None):
            for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups,
                                                       self.single_partition_of_fp32_groups):
                fp16_partitions[partition_id].data.copy_(fp32_partition.data)
-        timers('optimizer_step').stop()
+        self.stop_timers([OPTIMIZER_STEP])

         if self.cpu_offload:
             self.reset_cpu_buffers()

-        timers('optimizer_allgather').start()
+        self.start_timers([OPTIMIZER_ALLGATHER])
         #gather the updated weights from everyone
         for group_id, partitioned_params in enumerate(
                 self.parallel_partitioned_fp16_groups):
@@ -1474,7 +1492,7 @@ def step(self, closure=None):
                dist.all_gather(shard_list,
                                shard_list[partition_id],
                                group=self.dp_process_group)
-        timers('optimizer_allgather').stop()
+        self.stop_timers([OPTIMIZER_ALLGATHER])

         # TODO: we probably don't need this? just to be safe
         for i in range(len(norm_groups)):
@@ -1483,11 +1501,9 @@ def step(self, closure=None):
             for p, q in zip(self.fp16_groups[i], updated_params):
                 p.data = q.data

-        timers.log(
-            names=['optimizer_gradients',
-                   'optimizer_step',
-                   'optimizer_allgather'])
+        self.log_timers(timer_names)

         see_memory_usage('After zero_optimizer step')
+        return

     def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index d2c197fa93c8..f840de15c57d 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -580,7 +580,7 @@ def __init__(self,
                  gradient_accumulation_steps=1,
                  elastic_checkpoint=False):

-        see_memory_usage("Stage 3 intialize begining", force=True)
+        see_memory_usage("Stage 3 initialize beginning", force=True)

         if dist.get_rank() == 0:
             logger.info(f"Reduce bucket size {reduce_bucket_size}")
@@ -628,7 +628,7 @@ def __init__(self,
         self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'
         ############################################################################

-        see_memory_usage("Before Partitioned Parameter Coordinator", force=True)
+        see_memory_usage("Before Partitioned Parameter Coordinator", force=False)

         fetch_stream = torch.cuda.Stream() if self.overlap_comm else None
         self.param_coordinator = PartitionedParameterCoordinator(
@@ -636,7 +636,7 @@ def __init__(self,
             max_reuse_distance_in_numel=int(max_reuse_distance),
             max_available_parameters_in_numel=int(max_live_parameters))

-        see_memory_usage("After Partitioned Parameter Coordinator", force=True)
+        see_memory_usage("After Partitioned Parameter Coordinator", force=False)

         #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream())
         #-------------Stage 3 Setup-------------------#
@@ -711,20 +711,20 @@ def __init__(self,

         self.sub_group_to_group_id = {}

-        see_memory_usage("Before creating fp16 partitions", force=True)
+        see_memory_usage("Before creating fp16 partitions", force=False)
         #self._create_fp16_partitions()
         self._create_fp16_partitions_with_defragmentation()
         num_fp16_subgroups = len(self.fp16_partitioned_groups_flat)
         see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}",
-                         force=True)
+                         force=False)

-        see_memory_usage("Before creating fp32 partitions", force=True)
+        see_memory_usage("Before creating fp32 partitions", force=False)
         self._create_fp32_partitions()
-        see_memory_usage("After creating fp32 partitions", force=True)
+        see_memory_usage("After creating fp32 partitions", force=False)

-        see_memory_usage("Before initializing optimizer states", force=True)
+        see_memory_usage("Before initializing optimizer states", force=False)
         self.initialize_optimizer_states()
-        see_memory_usage("After initializing optimizer states", force=True)
+        see_memory_usage("After initializing optimizer states", force=False)

         if dist.get_rank() == 0:
             logger.info(f"optimizer state initialized")
@@ -767,11 +767,11 @@ def __init__(self,
         #Largest partitioned param
         largest_partitioned_param_numel = self._get_largest_partitioned_numel()

-        see_memory_usage(f"Before Set Grad positions", force=True)
+        see_memory_usage(f"Before Set Grad positions", force=False)

         self.grad_position = {}
         self.set_grad_positions()
-        see_memory_usage(f"Before CPU Offload initialization", force=True)
+        see_memory_usage(f"Before CPU Offload initialization", force=False)

         self.grads_in_partition = None

@@ -785,7 +785,7 @@ def __init__(self,
            self.temp_grad_gpu_buffer = torch.zeros(
                largest_partitioned_param_numel,
                device=torch.cuda.current_device()).half()
-        see_memory_usage(f"After CPU Offload initialization", force=True)
+        see_memory_usage(f"After CPU Offload initialization", force=False)

         # stores if a partition has been reduced in this step
         self.is_partition_reduced = {}
@@ -1614,7 +1614,7 @@ def partition_previous_reduced_grads(self):

             see_memory_usage(
                 f"group {i} before creating {total_size} reduced gradients into partition",
-                force=True)
+                force=False)
             if self.cpu_offload_use_pin_memory:
                 self.grads_in_partition.append(
                     torch.zeros(int(total_size),
@@ -1627,7 +1627,7 @@ def partition_previous_reduced_grads(self):
                                 device=self.device))
             see_memory_usage(
                 f"group {i} after creating {total_size} reduced gradients into partition",
-                force=True)
+                force=False)


             for param in self.previous_reduced_grads:
@@ -2044,13 +2044,22 @@ def reset_cpu_buffers(self):
         self.local_overflow = False

     def log_timers(self, timer_names):
+        if self.timers is None:
+            return
+
         self.timers.log(names=list(timer_names))

     def start_timers(self, timer_names):
+        if self.timers is None:
+            return
+
         for name in timer_names:
             self.timers(name).start()

     def stop_timers(self, timer_names):
+        if self.timers is None:
+            return
+
         for name in timer_names:
             self.timers(name).stop()

@@ -2210,7 +2219,7 @@ def old_step(self, closure=None):

         see_memory_usage('After zero_optimizer step', force=False)
         print_rank_0(f"------------------Finishing Step-----------------------",
-                     force=True)
+                     force=False)
         return

     def _pre_step(self):
@@ -2327,7 +2336,7 @@ def _post_step(self, timer_names=set()):

         self.log_timers(timer_names)

-        see_memory_usage('After zero_optimizer step', force=True)
+        see_memory_usage('After zero_optimizer step', force=False)
         print_rank_0(f"------------------Finishing Step-----------------------")

     def step(self, closure=None):
@@ -2342,7 +2351,6 @@ def step(self, closure=None):

         norm_groups = self._get_norm_groups()

-        timers = self.timers
         timer_names = set()

         timer_names.add('optimizer_step')
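
Note on the pattern above: the engine now passes timers=None to the ZeRO optimizers whenever wall_clock_breakdown() is disabled, and both stage2.py and stage3.py route every timer call through start_timers/stop_timers/log_timers, which no-op on None. The standalone sketch below illustrates the same null-object guard outside DeepSpeed; the _Timer and Timers classes are hypothetical stand-ins, not DeepSpeed's actual SynchronizedWallClockTimer API:

    import time

    class _Timer:
        # Hypothetical single timer: accumulates elapsed wall-clock time.
        def __init__(self):
            self.elapsed = 0.0
            self._start = None

        def start(self):
            self._start = time.time()

        def stop(self):
            self.elapsed += time.time() - self._start
            self._start = None

    class Timers:
        # Hypothetical registry mimicking the callable timers(name) interface.
        def __init__(self):
            self._timers = {}

        def __call__(self, name):
            return self._timers.setdefault(name, _Timer())

        def log(self, names):
            for name in names:
                print(f'{name}: {self._timers[name].elapsed * 1000:.2f} ms')

    class Optimizer:
        # self.timers is a Timers instance or None, exactly as the engine
        # now passes it: timers if wall_clock_breakdown() else None.
        def __init__(self, timers=None):
            self.timers = timers

        def start_timers(self, timer_names):
            if self.timers is None:
                return
            for name in timer_names:
                self.timers(name).start()

        def stop_timers(self, timer_names):
            if self.timers is None:
                return
            for name in timer_names:
                self.timers(name).stop()

        def log_timers(self, timer_names):
            if self.timers is None:
                return
            self.timers.log(names=list(timer_names))

        def step(self):
            # Same shape as the instrumented step() in the diff: guarded
            # helpers wrap the timed region.
            self.start_timers(['optimizer_step'])
            time.sleep(0.01)  # stand-in for the real optimizer work
            self.stop_timers(['optimizer_step'])
            self.log_timers(['optimizer_step'])

    Optimizer(timers=Timers()).step()  # wall-clock breakdown enabled
    Optimizer(timers=None).step()      # disabled: every timer helper no-ops

Centralizing the None check in three helpers, instead of guarding each call site, keeps step() readable and reduces the untimed path to one attribute comparison per helper call.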