diff --git a/classy_vision/hooks/tensorboard_plot_hook.py b/classy_vision/hooks/tensorboard_plot_hook.py index 4d5b906d23..57caee2710 100644 --- a/classy_vision/hooks/tensorboard_plot_hook.py +++ b/classy_vision/hooks/tensorboard_plot_hook.py @@ -35,7 +35,7 @@ class TensorboardPlotHook(ClassyHook): on_start = ClassyHook._noop on_end = ClassyHook._noop - def __init__(self, tb_writer) -> None: + def __init__(self, tb_writer, log_period: int = 10) -> None: """The constructor method of TensorboardPlotHook. Args: @@ -52,13 +52,15 @@ def __init__(self, tb_writer) -> None: self.tb_writer = tb_writer self.learning_rates: Optional[List[float]] = None self.wall_times: Optional[List[float]] = None - self.num_steps_global: Optional[List[int]] = None + self.num_updates: Optional[List[int]] = None + self.log_period = log_period def on_phase_start(self, task: "tasks.ClassyTask") -> None: """Initialize losses and learning_rates.""" self.learning_rates = [] self.wall_times = [] - self.num_steps_global = [] + self.num_updates = [] + self.step_idx = 0 if not is_master(): return @@ -80,11 +82,14 @@ def on_step(self, task: "tasks.ClassyTask") -> None: # Only need to log the average loss during the test phase return - learning_rate_val = task.optimizer.parameters.lr + if self.step_idx % self.log_period == 0: + learning_rate_val = task.optimizer.parameters.lr - self.learning_rates.append(learning_rate_val) - self.wall_times.append(time.time()) - self.num_steps_global.append(task.num_updates) + self.learning_rates.append(learning_rate_val) + self.wall_times.append(time.time()) + self.num_updates.append(task.num_updates) + + self.step_idx += 1 def on_phase_end(self, task: "tasks.ClassyTask") -> None: """Add the losses and learning rates to tensorboard.""" @@ -106,7 +111,7 @@ def on_phase_end(self, task: "tasks.ClassyTask") -> None: if task.train: for learning_rate, global_step, wall_time in zip( - self.learning_rates, self.num_steps_global, self.wall_times + self.learning_rates, self.num_updates, self.wall_times ): self.tb_writer.add_scalar( learning_rate_key, diff --git a/test/manual/hooks_tensorboard_plot_hook_test.py b/test/manual/hooks_tensorboard_plot_hook_test.py index 6214c4f398..fb73f07dae 100644 --- a/test/manual/hooks_tensorboard_plot_hook_test.py +++ b/test/manual/hooks_tensorboard_plot_hook_test.py @@ -147,6 +147,7 @@ def flush(self): writer = DummySummaryWriter() hook = TensorboardPlotHook(writer) + hook.log_period = 1 task.set_hooks([hook]) task.optimizer.param_schedulers["lr"] = mock_lr_scheduler