From 1488b0653b7d6dbb559306459bb3e3cd95f4d984 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 9 Feb 2023 12:04:44 +0100
Subject: [PATCH 1/4] Added callable options for iteration_log and epoch_log in StatsHandler

Fixes #5964

Signed-off-by: vfdev-5
---
 monai/handlers/stats_handler.py |  28 +++++--
 tests/test_handler_stats.py     | 137 +++++++++++++++++++-------
 2 files changed, 102 insertions(+), 63 deletions(-)

diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py
index 58917e666b..ea2b299c04 100644
--- a/monai/handlers/stats_handler.py
+++ b/monai/handlers/stats_handler.py
@@ -66,8 +66,8 @@ class StatsHandler:
 
     def __init__(
         self,
-        iteration_log: bool = True,
-        epoch_log: bool = True,
+        iteration_log: bool | Callable[[Engine, int], bool] = True,
+        epoch_log: bool | Callable[[Engine, int], bool] = True,
         epoch_print_logger: Callable[[Engine], Any] | None = None,
         iteration_print_logger: Callable[[Engine], Any] | None = None,
         output_transform: Callable = lambda x: x[0],
@@ -80,8 +80,14 @@ def __init__(
         """
 
         Args:
-            iteration_log: whether to log data when iteration completed, default to `True`.
-            epoch_log: whether to log data when epoch completed, default to `True`.
+            iteration_log: whether to log data when iteration completed, default to `True`. ``iteration_log`` can
+                also be a function, in which case it will be interpreted as an event filter
+                (see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).
+                The event filter function accepts as inputs the engine and the event value (the iteration number)
+                and should return True or False. Event filtering can be used to customize the logging frequency.
+            epoch_log: whether to log data when epoch completed, default to `True`. ``epoch_log`` can also be
+                a function, in which case it will be interpreted as an event filter. See the ``iteration_log``
+                argument for more details.
             epoch_print_logger: customized callable printer for epoch level logging.
                 Must accept parameter "engine", use default printer if None.
             iteration_print_logger: customized callable printer for iteration level logging.
@@ -135,9 +141,19 @@ def attach(self, engine: Engine) -> None:
                 " please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it."
) if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): - engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) + event = ( + Events.ITERATION_COMPLETED(event_filter=self.iteration_log) + if callable(self.iteration_log) + else Events.ITERATION_COMPLETED + ) + engine.add_event_handler(event, self.iteration_completed) if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): - engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed) + event = ( + Events.EPOCH_COMPLETED(event_filter=self.epoch_log) + if callable(self.epoch_log) + else Events.EPOCH_COMPLETED + ) + engine.add_event_handler(event, self.epoch_completed) if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED): engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised) diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index cb93f93a29..38356a89b2 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -26,74 +26,97 @@ class TestHandlerStats(unittest.TestCase): def test_metrics_print(self): - log_stream = StringIO() - log_handler = logging.StreamHandler(log_stream) - log_handler.setLevel(logging.INFO) - key_to_handler = "test_logging" - key_to_print = "testing_metric" + def event_filter(_, event): + if event in [1, 2]: + return True + return False + + for epoch_log in [True, event_filter]: + log_stream = StringIO() + log_handler = logging.StreamHandler(log_stream) + log_handler.setLevel(logging.INFO) + key_to_handler = "test_logging" + key_to_print = "testing_metric" - # set up engine - def _train_func(engine, batch): - return [torch.tensor(0.0)] + # set up engine + def _train_func(engine, batch): + return [torch.tensor(0.0)] - engine = Engine(_train_func) + engine = Engine(_train_func) - # set up dummy metric - @engine.on(Events.EPOCH_COMPLETED) - def _update_metric(engine): - current_metric = engine.state.metrics.get(key_to_print, 0.1) - engine.state.metrics[key_to_print] = current_metric + 0.1 + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get(key_to_print, 0.1) + engine.state.metrics[key_to_print] = current_metric + 0.1 - # set up testing handler - logger = logging.getLogger(key_to_handler) - logger.setLevel(logging.INFO) - logger.addHandler(log_handler) - stats_handler = StatsHandler(iteration_log=False, epoch_log=True, name=key_to_handler) - stats_handler.attach(engine) - - engine.run(range(3), max_epochs=2) + # set up testing handler + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(log_handler) + stats_handler = StatsHandler(iteration_log=False, epoch_log=epoch_log, name=key_to_handler) + stats_handler.attach(engine) - # check logging output - output_str = log_stream.getvalue() - log_handler.close() - has_key_word = re.compile(f".*{key_to_print}.*") - content_count = 0 - for line in output_str.split("\n"): - if has_key_word.match(line): - content_count += 1 - self.assertTrue(content_count > 0) + max_epochs = 4 + engine.run(range(3), max_epochs=max_epochs) + + # check logging output + output_str = log_stream.getvalue() + log_handler.close() + has_key_word = re.compile(f".*{key_to_print}.*") + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + if epoch_log is True: + self.assertTrue(content_count == max_epochs) + else: + 
self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter def test_loss_print(self): - log_stream = StringIO() - log_handler = logging.StreamHandler(log_stream) - log_handler.setLevel(logging.INFO) - key_to_handler = "test_logging" - key_to_print = "myLoss" - - # set up engine - def _train_func(engine, batch): - return [torch.tensor(0.0)] + def event_filter(_, event): + if event in [1, 3]: + return True + return False + + for iteration_log in [True, event_filter]: + log_stream = StringIO() + log_handler = logging.StreamHandler(log_stream) + log_handler.setLevel(logging.INFO) + key_to_handler = "test_logging" + key_to_print = "myLoss" - engine = Engine(_train_func) + # set up engine + def _train_func(engine, batch): + return [torch.tensor(0.0)] - # set up testing handler - logger = logging.getLogger(key_to_handler) - logger.setLevel(logging.INFO) - logger.addHandler(log_handler) - stats_handler = StatsHandler(iteration_log=True, epoch_log=False, name=key_to_handler, tag_name=key_to_print) - stats_handler.attach(engine) + engine = Engine(_train_func) - engine.run(range(3), max_epochs=2) + # set up testing handler + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(log_handler) + stats_handler = StatsHandler( + iteration_log=iteration_log, epoch_log=False, name=key_to_handler, tag_name=key_to_print + ) + stats_handler.attach(engine) - # check logging output - output_str = log_stream.getvalue() - log_handler.close() - has_key_word = re.compile(f".*{key_to_print}.*") - content_count = 0 - for line in output_str.split("\n"): - if has_key_word.match(line): - content_count += 1 - self.assertTrue(content_count > 0) + num_iters = 3 + max_epochs = 2 + engine.run(range(num_iters), max_epochs=max_epochs) + + # check logging output + output_str = log_stream.getvalue() + log_handler.close() + has_key_word = re.compile(f".*{key_to_print}.*") + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + if iteration_log is True: + self.assertTrue(content_count == num_iters * max_epochs) + else: + self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter def test_loss_dict(self): log_stream = StringIO() From 12460e0b9a43e6438564b827a6c1233351d64b80 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 9 Feb 2023 12:18:41 +0000 Subject: [PATCH 2/4] update based on comments Signed-off-by: Wenqi Li --- tests/test_handler_stats.py | 162 ++++++++++++++++++------------------ 1 file changed, 81 insertions(+), 81 deletions(-) diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index 38356a89b2..84477f9221 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -20,103 +20,103 @@ import torch from ignite.engine import Engine, Events +from parameterized import parameterized from monai.handlers import StatsHandler +def get_event_filter(e): + def event_filter(_, event): + if event in e: + return True + return False + + return event_filter + + class TestHandlerStats(unittest.TestCase): - def test_metrics_print(self): - def event_filter(_, event): - if event in [1, 2]: - return True - return False - - for epoch_log in [True, event_filter]: - log_stream = StringIO() - log_handler = logging.StreamHandler(log_stream) - log_handler.setLevel(logging.INFO) - key_to_handler = "test_logging" - key_to_print = "testing_metric" + @parameterized.expand([[True], [get_event_filter([1, 2])]]) + def test_metrics_print(self, epoch_log): + log_stream = StringIO() + log_handler = 
logging.StreamHandler(log_stream) + log_handler.setLevel(logging.INFO) + key_to_handler = "test_logging" + key_to_print = "testing_metric" - # set up engine - def _train_func(engine, batch): - return [torch.tensor(0.0)] + # set up engine + def _train_func(engine, batch): + return [torch.tensor(0.0)] - engine = Engine(_train_func) + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get(key_to_print, 0.1) + engine.state.metrics[key_to_print] = current_metric + 0.1 - # set up dummy metric - @engine.on(Events.EPOCH_COMPLETED) - def _update_metric(engine): - current_metric = engine.state.metrics.get(key_to_print, 0.1) - engine.state.metrics[key_to_print] = current_metric + 0.1 + # set up testing handler + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(log_handler) + stats_handler = StatsHandler(iteration_log=False, epoch_log=epoch_log, name=key_to_handler) + stats_handler.attach(engine) - # set up testing handler - logger = logging.getLogger(key_to_handler) - logger.setLevel(logging.INFO) - logger.addHandler(log_handler) - stats_handler = StatsHandler(iteration_log=False, epoch_log=epoch_log, name=key_to_handler) - stats_handler.attach(engine) + max_epochs = 4 + engine.run(range(3), max_epochs=max_epochs) - max_epochs = 4 - engine.run(range(3), max_epochs=max_epochs) - - # check logging output - output_str = log_stream.getvalue() - log_handler.close() - has_key_word = re.compile(f".*{key_to_print}.*") - content_count = 0 - for line in output_str.split("\n"): - if has_key_word.match(line): - content_count += 1 - if epoch_log is True: - self.assertTrue(content_count == max_epochs) - else: - self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter + # check logging output + output_str = log_stream.getvalue() + log_handler.close() + has_key_word = re.compile(f".*{key_to_print}.*") + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + if epoch_log is True: + self.assertTrue(content_count == max_epochs) + else: + self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter - def test_loss_print(self): - def event_filter(_, event): - if event in [1, 3]: - return True - return False + @parameterized.expand([[True], [get_event_filter([1, 3])]]) + def test_loss_print(self, iteration_log): + log_stream = StringIO() + log_handler = logging.StreamHandler(log_stream) + log_handler.setLevel(logging.INFO) + key_to_handler = "test_logging" + key_to_print = "myLoss" - for iteration_log in [True, event_filter]: - log_stream = StringIO() - log_handler = logging.StreamHandler(log_stream) - log_handler.setLevel(logging.INFO) - key_to_handler = "test_logging" - key_to_print = "myLoss" + # set up engine + def _train_func(engine, batch): + return [torch.tensor(0.0)] - # set up engine - def _train_func(engine, batch): - return [torch.tensor(0.0)] + engine = Engine(_train_func) - engine = Engine(_train_func) + # set up testing handler + logger = logging.getLogger(key_to_handler) + logger.setLevel(logging.INFO) + logger.addHandler(log_handler) + stats_handler = StatsHandler( + iteration_log=iteration_log, epoch_log=False, name=key_to_handler, tag_name=key_to_print + ) + stats_handler.attach(engine) - # set up testing handler - logger = logging.getLogger(key_to_handler) - logger.setLevel(logging.INFO) - logger.addHandler(log_handler) - stats_handler = StatsHandler( - 
iteration_log=iteration_log, epoch_log=False, name=key_to_handler, tag_name=key_to_print - ) - stats_handler.attach(engine) + num_iters = 3 + max_epochs = 2 + engine.run(range(num_iters), max_epochs=max_epochs) - num_iters = 3 - max_epochs = 2 - engine.run(range(num_iters), max_epochs=max_epochs) - - # check logging output - output_str = log_stream.getvalue() - log_handler.close() - has_key_word = re.compile(f".*{key_to_print}.*") - content_count = 0 - for line in output_str.split("\n"): - if has_key_word.match(line): - content_count += 1 - if iteration_log is True: - self.assertTrue(content_count == num_iters * max_epochs) - else: - self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter + # check logging output + output_str = log_stream.getvalue() + log_handler.close() + has_key_word = re.compile(f".*{key_to_print}.*") + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + if iteration_log is True: + self.assertTrue(content_count == num_iters * max_epochs) + else: + self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter def test_loss_dict(self): log_stream = StringIO() From 6d1c71c4ad6407f31091fd46359f585d322d46ed Mon Sep 17 00:00:00 2001 From: vfdev Date: Thu, 9 Feb 2023 23:06:17 +0100 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: vfdev --- monai/handlers/stats_handler.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index ea2b299c04..340ff2385b 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -141,18 +141,14 @@ def attach(self, engine: Engine) -> None: " please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it." 
) if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): - event = ( - Events.ITERATION_COMPLETED(event_filter=self.iteration_log) - if callable(self.iteration_log) - else Events.ITERATION_COMPLETED - ) + event = Events.ITERATION_COMPLETED + if callable(self.iteration_log): # substitute event with new one using filter callable + event = event(event_filter=self.iteration_log) engine.add_event_handler(event, self.iteration_completed) if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): - event = ( - Events.EPOCH_COMPLETED(event_filter=self.epoch_log) - if callable(self.epoch_log) - else Events.EPOCH_COMPLETED - ) + event = Events.EPOCH_COMPLETED + if callable(self.epoch_log): # substitute event with new one using filter callable + event = event(event_filter=self.epoch_log) engine.add_event_handler(event, self.epoch_completed) if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED): engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised) From 0e16f9d0a91107e90f759664272a3b6bf1b75736 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Feb 2023 22:06:41 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/handlers/stats_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index 340ff2385b..8471a87e8e 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -143,12 +143,12 @@ def attach(self, engine: Engine) -> None: if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): event = Events.ITERATION_COMPLETED if callable(self.iteration_log): # substitute event with new one using filter callable - event = event(event_filter=self.iteration_log) + event = event(event_filter=self.iteration_log) engine.add_event_handler(event, self.iteration_completed) if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): event = Events.EPOCH_COMPLETED if callable(self.epoch_log): # substitute event with new one using filter callable - event = event(event_filter=self.epoch_log) + event = event(event_filter=self.epoch_log) engine.add_event_handler(event, self.epoch_completed) if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED): engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised)
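Not part of the patch series itself, but as a rough usage sketch of the feature it adds: the snippet below follows the pattern of the updated tests and passes a callable as ``iteration_log`` so that iteration stats are printed only every tenth iteration. The ``every_tenth_iteration`` filter, the dummy ``_train_func``, and the ``train_loss`` tag are illustrative assumptions, not code from the patch.

import logging
import sys

import torch
from ignite.engine import Engine

from monai.handlers import StatsHandler

# make INFO-level messages visible on stdout, as suggested by the handler's own warning
logging.basicConfig(stream=sys.stdout, level=logging.INFO)


def _train_func(engine, batch):
    # dummy training step returning a constant "loss" value
    return [torch.tensor(0.0)]


def every_tenth_iteration(engine, event):
    # event filter: `event` is the iteration number; return True to log this iteration
    return event % 10 == 0


engine = Engine(_train_func)
stats_handler = StatsHandler(
    iteration_log=every_tenth_iteration,  # callable instead of bool, enabled by this patch
    epoch_log=True,
    tag_name="train_loss",
)
stats_handler.attach(engine)
engine.run(range(20), max_epochs=2)  # iteration stats printed at iterations 10, 20, 30, 40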