From 6552d85b654205cd1a21a785ba2a27b9cfebd937 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Tue, 4 Jan 2022 21:57:22 +0800
Subject: [PATCH 1/8] [DLMED] add log handler

Signed-off-by: Nic Ma
---
 monai/handlers/lr_schedule_handler.py  |  6 ++++
 monai/handlers/stats_handler.py        |  3 +-
 monai/transforms/utility/array.py      |  2 +-
 monai/transforms/utility/dictionary.py |  2 +-
 tests/test_handler_lr_scheduler.py     | 41 ++++++++++++++++++++------
 5 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py
index db186bd73d..207af306d4 100644
--- a/monai/handlers/lr_schedule_handler.py
+++ b/monai/handlers/lr_schedule_handler.py
@@ -36,6 +36,7 @@ def __init__(
         name: Optional[str] = None,
         epoch_level: bool = True,
         step_transform: Callable[[Engine], Any] = lambda engine: (),
+        logger_handler: Optional[logging.Handler] = None,
     ) -> None:
         """
         Args:
@@ -47,6 +48,9 @@ def __init__(
                 `True` is epoch level, `False` is iteration level.
             step_transform: a callable that is used to transform the information from `engine`
                 to expected input data of lr_scheduler.step() function if necessary.
+            logger_handler: if `print_lr` is True, add additional handler to log the learning rate: save to file, etc.
+                all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
+                the handler should have a logging level of at least `INFO`.

         Raises:
             TypeError: When ``step_transform`` is not ``callable``.
@@ -59,6 +63,8 @@ def __init__(
         if not callable(step_transform):
             raise TypeError(f"step_transform must be callable but is {type(step_transform).__name__}.")
         self.step_transform = step_transform
+        if logger_handler is not None:
+            self.logger.addHandler(logger_handler)

         self._name = name

diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py
index e90d0ebd10..8410aaec87 100644
--- a/monai/handlers/stats_handler.py
+++ b/monai/handlers/stats_handler.py
@@ -82,7 +82,8 @@ def __init__(
                 tag_name: scalar_value to logger. Defaults to ``'Loss'``.
             key_var_format: a formatting string to control the output string format of key: value.
             logger_handler: add additional handler to handle the stats data: save to file, etc.
-                Add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
+                all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
+                the handler should have a logging level of at least `INFO`.
         """

         self.epoch_print_logger = epoch_print_logger
diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index b869d2489c..a107cf1cb1 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -570,7 +570,7 @@ def __init__(
                 a typical example is to print some properties of Nifti image: affine, pixdim, etc.
             additional_info: user can define callable function to extract additional info from input data.
             logger_handler: add additional handler to output data: save to file, etc.
-                add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
+                all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
                 the handler should have a logging level of at least `INFO`.

         Raises:
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index b74e63f683..b611d2ed30 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -795,7 +795,7 @@ def __init__(
                 additional info from input data. it also can be a sequence of string, each element
                 corresponds to a key in ``keys``.
             logger_handler: add additional handler to output data: save to file, etc.
-                add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
+                all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
                 the handler should have a logging level of at least `INFO`.
             allow_missing_keys: don't raise exception if key is missing.

diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py
index cbb752925b..7c7b1126af 100644
--- a/tests/test_handler_lr_scheduler.py
+++ b/tests/test_handler_lr_scheduler.py
@@ -10,7 +10,10 @@
 # limitations under the License.

 import logging
+import os
+import re
 import sys
+import tempfile
 import unittest

 import numpy as np
@@ -24,6 +27,8 @@ class TestHandlerLrSchedule(unittest.TestCase):
     def test_content(self):
         logging.basicConfig(stream=sys.stdout, level=logging.INFO)
         data = [0] * 8
+        test_lr = 0.1
+        gamma = 0.1

         # set up engine
         def _train_func(engine, batch):
@@ -41,22 +46,40 @@ def run_validation(engine):
         net = torch.nn.PReLU()

         def _reduce_lr_on_plateau():
-            optimizer = torch.optim.SGD(net.parameters(), 0.1)
+            optimizer = torch.optim.SGD(net.parameters(), test_lr)
             lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1)
             handler = LrScheduleHandler(lr_scheduler, step_transform=lambda x: val_engine.state.metrics["val_loss"])
             handler.attach(train_engine)
             return lr_scheduler

-        def _reduce_on_step():
-            optimizer = torch.optim.SGD(net.parameters(), 0.1)
-            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
-            handler = LrScheduleHandler(lr_scheduler)
-            handler.attach(train_engine)
-            return lr_scheduler
+        with tempfile.TemporaryDirectory() as tempdir:
+            filename = os.path.join(tempdir, "test_lr.log")
+            # test with additional logging handler
+            file_saver = logging.FileHandler(filename, mode="w")
+            file_saver.setLevel(logging.INFO)
+
+            def _reduce_on_step():
+                optimizer = torch.optim.SGD(net.parameters(), test_lr)
+                lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=gamma)
+                handler = LrScheduleHandler(lr_scheduler, logger_handler=file_saver)
+                handler.attach(train_engine)
+                return lr_scheduler
+
+            schedulers = _reduce_lr_on_plateau(), _reduce_on_step()
+
+            train_engine.run(data, max_epochs=5)
+            file_saver.close()
+            train_engine.logger.removeHandler(file_saver)

-        schedulers = _reduce_lr_on_plateau(), _reduce_on_step()
+            with open(filename) as f:
+                output_str = f.read()
+                grep = re.compile(".*Current learning rate.*")
+                content_count = 0
+                for line in output_str.split("\n"):
+                    if grep.match(line):
+                        content_count += 1
+                self.assertTrue(content_count > 0)

-        train_engine.run(data, max_epochs=5)
         for scheduler in schedulers:
             np.testing.assert_allclose(scheduler._last_lr[0], 0.001)

From a3b4fb79079443646ac05dc86d89e461d1f3f1c3 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Tue, 4 Jan 2022 23:48:27 +0800
Subject: [PATCH 2/8] [DLMED] fix CI tests

Signed-off-by: Nic Ma
---
 tests/test_handler_lr_scheduler.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py
index 7c7b1126af..c1d8c5a48a 100644
--- a/tests/test_handler_lr_scheduler.py
+++ b/tests/test_handler_lr_scheduler.py
@@ -50,7 +50,7 @@ def _reduce_lr_on_plateau():
             lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1)
             handler = LrScheduleHandler(lr_scheduler, step_transform=lambda x: val_engine.state.metrics["val_loss"])
             handler.attach(train_engine)
-            return lr_scheduler
+            return handler

         with tempfile.TemporaryDirectory() as tempdir:
             filename = os.path.join(tempdir, "test_lr.log")
@@ -63,13 +63,13 @@ def _reduce_on_step():
                 lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=gamma)
                 handler = LrScheduleHandler(lr_scheduler, logger_handler=file_saver)
                 handler.attach(train_engine)
-                return lr_scheduler
+                return handler

             schedulers = _reduce_lr_on_plateau(), _reduce_on_step()

             train_engine.run(data, max_epochs=5)
             file_saver.close()
-            train_engine.logger.removeHandler(file_saver)
+            schedulers[1].logger.removeHandler(file_saver)

             with open(filename) as f:
                 output_str = f.read()
@@ -81,7 +81,7 @@ def _reduce_on_step():
                 self.assertTrue(content_count > 0)

         for scheduler in schedulers:
-            np.testing.assert_allclose(scheduler._last_lr[0], 0.001)
+            np.testing.assert_allclose(scheduler.lr_scheduler._last_lr[0], 0.001)


 if __name__ == "__main__":

From 279d4b4bf9ed66501eaa3a4db554cc33b7c1dafc Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Wed, 5 Jan 2022 00:17:26 +0800
Subject: [PATCH 3/8] [DLMED] fix CI test

Signed-off-by: Nic Ma
---
 tests/test_handler_lr_scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py
index c1d8c5a48a..b339b89fa1 100644
--- a/tests/test_handler_lr_scheduler.py
+++ b/tests/test_handler_lr_scheduler.py
@@ -61,7 +61,7 @@ def _reduce_lr_on_plateau():
             def _reduce_on_step():
                 optimizer = torch.optim.SGD(net.parameters(), test_lr)
                 lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=gamma)
-                handler = LrScheduleHandler(lr_scheduler, logger_handler=file_saver)
+                handler = LrScheduleHandler(lr_scheduler, name="test_logging", logger_handler=file_saver)
                 handler.attach(train_engine)
                 return handler

From f0a9539be8ad0665b4b7d95b1f148c73f8a92742 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Wed, 5 Jan 2022 00:45:01 +0800
Subject: [PATCH 4/8] [DLMED] test CI

Signed-off-by: Nic Ma
---
 tests/test_handler_lr_scheduler.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py
index b339b89fa1..d1b93c6734 100644
--- a/tests/test_handler_lr_scheduler.py
+++ b/tests/test_handler_lr_scheduler.py
@@ -73,11 +73,13 @@ def _reduce_on_step():

             with open(filename) as f:
                 output_str = f.read()
-                grep = re.compile(".*Current learning rate.*")
+                grep = re.compile(".*learning rate.*")
                 content_count = 0
                 for line in output_str.split("\n"):
+                    print("!!!!!!!", line)
                     if grep.match(line):
                         content_count += 1
+                print("!!!!count:", content_count)
                 self.assertTrue(content_count > 0)

         for scheduler in schedulers:

From 87e12f59d4151d21eba588236c49abf94f60ce85 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Wed, 5 Jan 2022 07:32:32 +0800
Subject: [PATCH 5/8] [DLMED] fix logging

Signed-off-by: Nic Ma
---
 monai/handlers/lr_schedule_handler.py | 2 ++
 monai/handlers/stats_handler.py       | 1 +
 tests/test_handler_lr_scheduler.py    | 4 +---
 tests/test_handler_stats.py           | 3 +++
 4 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py
index 207af306d4..63da489f6f 100644
--- a/monai/handlers/lr_schedule_handler.py
+++ b/monai/handlers/lr_schedule_handler.py
@@ -75,6 +75,8 @@ def attach(self, engine: Engine) -> None:
         """
         if self._name is None:
             self.logger = engine.logger
+        if self.print_lr:
+            self.logger.setLevel(logging.INFO)
         if self.epoch_level:
             engine.add_event_handler(Events.EPOCH_COMPLETED, self)
         else:
diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py
index 8410aaec87..283f487da2 100644
--- a/monai/handlers/stats_handler.py
+++ b/monai/handlers/stats_handler.py
@@ -109,6 +109,7 @@ def attach(self, engine: Engine) -> None:
         """
         if self._name is None:
             self.logger = engine.logger
+            self.logger.setLevel(logging.INFO)
         if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
             engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
         if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py
index d1b93c6734..b339b89fa1 100644
--- a/tests/test_handler_lr_scheduler.py
+++ b/tests/test_handler_lr_scheduler.py
@@ -73,13 +73,11 @@ def _reduce_on_step():

             with open(filename) as f:
                 output_str = f.read()
-                grep = re.compile(".*learning rate.*")
+                grep = re.compile(".*Current learning rate.*")
                 content_count = 0
                 for line in output_str.split("\n"):
-                    print("!!!!!!!", line)
                     if grep.match(line):
                         content_count += 1
-                print("!!!!count:", content_count)
                 self.assertTrue(content_count > 0)

         for scheduler in schedulers:
diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py
index aa6bc427e1..dfec06fb8a 100644
--- a/tests/test_handler_stats.py
+++ b/tests/test_handler_stats.py
@@ -144,10 +144,13 @@ def _train_func(engine, batch):
                 output_str = f.read()
                 grep = re.compile(f".*{key_to_handler}.*")
                 has_key_word = re.compile(f".*{key_to_print}.*")
+                content_count = 0
                 for idx, line in enumerate(output_str.split("\n")):
                     if grep.match(line):
                         if idx in [1, 2, 3, 6, 7, 8]:
+                            content_count += 1
                             self.assertTrue(has_key_word.match(line))
+                self.assertTrue(content_count > 0)

     def test_exception(self):
         # set up engine

From 7474ec10cc86088270f7d756a0cab61503d1066e Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Wed, 5 Jan 2022 08:07:43 +0800
Subject: [PATCH 6/8] [DLMED] temp test

Signed-off-by: Nic Ma
---
 tests/test_handler_stats.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py
index dfec06fb8a..4ca2ddfd2b 100644
--- a/tests/test_handler_stats.py
+++ b/tests/test_handler_stats.py
@@ -146,6 +146,7 @@ def _train_func(engine, batch):
                 has_key_word = re.compile(f".*{key_to_print}.*")
                 content_count = 0
                 for idx, line in enumerate(output_str.split("\n")):
+                    print("!!!!!!!!", line)  # for temp test
                     if grep.match(line):
                         if idx in [1, 2, 3, 6, 7, 8]:
                             content_count += 1

From d9ca0bcf57de55c8a36b4aac023d98ea6708af7c Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Wed, 5 Jan 2022 12:31:23 +0800
Subject: [PATCH 7/8] [DLMED] fix wrong unit test

Signed-off-by: Nic Ma
---
 monai/handlers/lr_schedule_handler.py |  2 --
 monai/handlers/stats_handler.py       |  1 -
 tests/test_handler_lr_scheduler.py    |  9 ++++++---
 tests/test_handler_stats.py           | 11 ++++-------
 4 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py
index 63da489f6f..207af306d4 100644
--- a/monai/handlers/lr_schedule_handler.py
+++ b/monai/handlers/lr_schedule_handler.py
@@ -75,8 +75,6 @@ def attach(self, engine: Engine) -> None:
         """
         if self._name is None:
             self.logger = engine.logger
-        if self.print_lr:
-            self.logger.setLevel(logging.INFO)
         if self.epoch_level:
             engine.add_event_handler(Events.EPOCH_COMPLETED, self)
         else:
diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py
index 283f487da2..8410aaec87 100644
--- a/monai/handlers/stats_handler.py
+++ b/monai/handlers/stats_handler.py
@@ -109,7 +109,6 @@ def attach(self, engine: Engine) -> None:
         """
         if self._name is None:
             self.logger = engine.logger
-            self.logger.setLevel(logging.INFO)
         if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
             engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
         if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py
index b339b89fa1..b260686315 100644
--- a/tests/test_handler_lr_scheduler.py
+++ b/tests/test_handler_lr_scheduler.py
@@ -53,6 +53,8 @@ def _reduce_lr_on_plateau():
             return handler

         with tempfile.TemporaryDirectory() as tempdir:
+            key_to_handler = "test_log_lr"
+            key_to_print = "Current learning rate"
             filename = os.path.join(tempdir, "test_lr.log")
             # test with additional logging handler
             file_saver = logging.FileHandler(filename, mode="w")
@@ -61,8 +63,9 @@ def _reduce_lr_on_plateau():
             def _reduce_on_step():
                 optimizer = torch.optim.SGD(net.parameters(), test_lr)
                 lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=gamma)
-                handler = LrScheduleHandler(lr_scheduler, name="test_logging", logger_handler=file_saver)
+                handler = LrScheduleHandler(lr_scheduler, name=key_to_handler, logger_handler=file_saver)
                 handler.attach(train_engine)
+                handler.logger.setLevel(logging.INFO)
                 return handler

             schedulers = _reduce_lr_on_plateau(), _reduce_on_step()

             train_engine.run(data, max_epochs=5)
             file_saver.close()
@@ -73,10 +76,10 @@ def _reduce_on_step():

             with open(filename) as f:
                 output_str = f.read()
-                grep = re.compile(".*Current learning rate.*")
+                has_key_word = re.compile(f".*{key_to_print}.*")
                 content_count = 0
                 for line in output_str.split("\n"):
-                    if grep.match(line):
+                    if has_key_word.match(line):
                         content_count += 1
                 self.assertTrue(content_count > 0)

diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py
index 4ca2ddfd2b..edd7bd1c29 100644
--- a/tests/test_handler_stats.py
+++ b/tests/test_handler_stats.py
@@ -136,21 +136,18 @@ def _train_func(engine, batch):
             # set up testing handler
             stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=handler)
             stats_handler.attach(engine)
+            stats_handler.logger.setLevel(logging.INFO)
             engine.run(range(3), max_epochs=2)
             handler.close()
             stats_handler.logger.removeHandler(handler)
             with open(filename) as f:
                 output_str = f.read()
-                grep = re.compile(f".*{key_to_handler}.*")
                 has_key_word = re.compile(f".*{key_to_print}.*")
                 content_count = 0
-                for idx, line in enumerate(output_str.split("\n")):
-                    print("!!!!!!!!", line)  # for temp test
-                    if grep.match(line):
-                        if idx in [1, 2, 3, 6, 7, 8]:
-                            content_count += 1
-                            self.assertTrue(has_key_word.match(line))
+                for line in output_str.split("\n"):
+                    if has_key_word.match(line):
+                        content_count += 1
                 self.assertTrue(content_count > 0)

     def test_exception(self):
         # set up engine

From 5ea7a33e1e580ea0b71dd7aea6921f8afac16495 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Wed, 5 Jan 2022 17:32:48 +0800
Subject: [PATCH 8/8] [DLMED] fix wrong test cases

Signed-off-by: Nic Ma
---
 tests/test_handler_stats.py | 45 ++++++++++++++++++++-----------------
 1 file changed, 25 insertions(+), 20 deletions(-)

diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py
index edd7bd1c29..ee0df74002 100644
--- a/tests/test_handler_stats.py
+++ b/tests/test_handler_stats.py
@@ -45,18 +45,19 @@ def _update_metric(engine):
         # set up testing handler
         stats_handler = StatsHandler(name=key_to_handler, logger_handler=log_handler)
         stats_handler.attach(engine)
+        stats_handler.logger.setLevel(logging.INFO)
         engine.run(range(3), max_epochs=2)

         # check logging output
         output_str = log_stream.getvalue()
         log_handler.close()
-        grep = re.compile(f".*{key_to_handler}.*")
         has_key_word = re.compile(f".*{key_to_print}.*")
-        for idx, line in enumerate(output_str.split("\n")):
-            if grep.match(line):
-                if idx in [5, 10]:
-                    self.assertTrue(has_key_word.match(line))
+        content_count = 0
+        for line in output_str.split("\n"):
+            if has_key_word.match(line):
+                content_count += 1
+        self.assertTrue(content_count > 0)

     def test_loss_print(self):
         log_stream = StringIO()
@@ -74,18 +75,19 @@ def _train_func(engine, batch):
         # set up testing handler
         stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=log_handler)
         stats_handler.attach(engine)
+        stats_handler.logger.setLevel(logging.INFO)
         engine.run(range(3), max_epochs=2)

         # check logging output
         output_str = log_stream.getvalue()
         log_handler.close()
-        grep = re.compile(f".*{key_to_handler}.*")
         has_key_word = re.compile(f".*{key_to_print}.*")
-        for idx, line in enumerate(output_str.split("\n")):
-            if grep.match(line):
-                if idx in [1, 2, 3, 6, 7, 8]:
-                    self.assertTrue(has_key_word.match(line))
+        content_count = 0
+        for line in output_str.split("\n"):
+            if has_key_word.match(line):
+                content_count += 1
+        self.assertTrue(content_count > 0)

     def test_loss_dict(self):
         log_stream = StringIO()
@@ -102,21 +104,22 @@ def _train_func(engine, batch):
         # set up testing handler
         stats_handler = StatsHandler(
-            name=key_to_handler, output_transform=lambda x: {key_to_print: x}, logger_handler=log_handler
+            name=key_to_handler, output_transform=lambda x: {key_to_print: x[0]}, logger_handler=log_handler
         )
         stats_handler.attach(engine)
+        stats_handler.logger.setLevel(logging.INFO)
         engine.run(range(3), max_epochs=2)

         # check logging output
         output_str = log_stream.getvalue()
         log_handler.close()
-        grep = re.compile(f".*{key_to_handler}.*")
         has_key_word = re.compile(f".*{key_to_print}.*")
-        for idx, line in enumerate(output_str.split("\n")):
-            if grep.match(line):
-                if idx in [1, 2, 3, 6, 7, 8]:
-                    self.assertTrue(has_key_word.match(line))
+        content_count = 0
+        for line in output_str.split("\n"):
+            if has_key_word.match(line):
+                content_count += 1
+        self.assertTrue(content_count > 0)

     def test_loss_file(self):
         key_to_handler = "test_logging"
@@ -191,17 +194,19 @@ def _update_metric(engine):
             name=key_to_handler, state_attributes=["test1", "test2", "test3"], logger_handler=log_handler
         )
         stats_handler.attach(engine)
+        stats_handler.logger.setLevel(logging.INFO)
         engine.run(range(3), max_epochs=2)

         # check logging output
         output_str = log_stream.getvalue()
         log_handler.close()
-        grep = re.compile(f".*{key_to_handler}.*")
         has_key_word = re.compile(".*State values.*")
-        for idx, line in enumerate(output_str.split("\n")):
-            if grep.match(line) and idx in [5, 10]:
-                self.assertTrue(has_key_word.match(line))
+        content_count = 0
+        for line in output_str.split("\n"):
+            if has_key_word.match(line):
+                content_count += 1
+        self.assertTrue(content_count > 0)


 if __name__ == "__main__":
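
Usage note (not part of the patches above): a minimal sketch of how the `logger_handler` argument introduced by this series can be combined with a plain `logging.FileHandler`, mirroring what tests/test_handler_lr_scheduler.py and tests/test_handler_stats.py do. The file name "lr_stats.log", the logger names "train_lr"/"train_stats", the tag name "train_loss", and the toy training function are illustrative assumptions, not part of MONAI's API.

import logging

import torch
from ignite.engine import Engine

from monai.handlers import LrScheduleHandler, StatsHandler

net = torch.nn.PReLU()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)


def _train_func(engine, batch):
    # minimal training step so the handlers have something to log
    optimizer.step()
    return [torch.tensor(0.0)]


trainer = Engine(_train_func)

# a standard python logging handler that persists the messages to a file (name is arbitrary)
file_handler = logging.FileHandler("lr_stats.log", mode="w")
file_handler.setLevel(logging.INFO)

# LrScheduleHandler forwards its "Current learning rate" messages to file_handler;
# the handler's own logger must also be at INFO level for the records to be emitted
lr_handler = LrScheduleHandler(lr_scheduler, name="train_lr", logger_handler=file_handler)
lr_handler.attach(trainer)
lr_handler.logger.setLevel(logging.INFO)

# StatsHandler accepts the same kind of handler for its loss/metrics logging
stats_handler = StatsHandler(name="train_stats", tag_name="train_loss", logger_handler=file_handler)
stats_handler.attach(trainer)
stats_handler.logger.setLevel(logging.INFO)

trainer.run([0.0] * 8, max_epochs=5)
file_handler.close()

# the log file can then be checked the same way the tests in this series do
with open("lr_stats.log") as f:
    assert any("Current learning rate" in line for line in f)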