diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py index db186bd73d..207af306d4 100644 --- a/monai/handlers/lr_schedule_handler.py +++ b/monai/handlers/lr_schedule_handler.py @@ -36,6 +36,7 @@ def __init__( name: Optional[str] = None, epoch_level: bool = True, step_transform: Callable[[Engine], Any] = lambda engine: (), + logger_handler: Optional[logging.Handler] = None, ) -> None: """ Args: @@ -47,6 +48,9 @@ def __init__( `True` is epoch level, `False` is iteration level. step_transform: a callable that is used to transform the information from `engine` to expected input data of lr_scheduler.step() function if necessary. + logger_handler: if `print_lr` is True, add additional handler to log the learning rate: save to file, etc. + all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html. + the handler should have a logging level of at least `INFO`. Raises: TypeError: When ``step_transform`` is not ``callable``. @@ -59,6 +63,8 @@ def __init__( if not callable(step_transform): raise TypeError(f"step_transform must be callable but is {type(step_transform).__name__}.") self.step_transform = step_transform + if logger_handler is not None: + self.logger.addHandler(logger_handler) self._name = name diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index e90d0ebd10..8410aaec87 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -82,7 +82,8 @@ def __init__( tag_name: scalar_value to logger. Defaults to ``'Loss'``. key_var_format: a formatting string to control the output string format of key: value. logger_handler: add additional handler to handle the stats data: save to file, etc. - Add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html + all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html. + the handler should have a logging level of at least `INFO`. """ self.epoch_print_logger = epoch_print_logger diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index b869d2489c..a107cf1cb1 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -570,7 +570,7 @@ def __init__( a typical example is to print some properties of Nifti image: affine, pixdim, etc. additional_info: user can define callable function to extract additional info from input data. logger_handler: add additional handler to output data: save to file, etc. - add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html + all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html. the handler should have a logging level of at least `INFO`. Raises: diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index b74e63f683..b611d2ed30 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -795,7 +795,7 @@ def __init__( additional info from input data. it also can be a sequence of string, each element corresponds to a key in ``keys``. logger_handler: add additional handler to output data: save to file, etc. - add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html + all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html. the handler should have a logging level of at least `INFO`. allow_missing_keys: don't raise exception if key is missing. diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py index cbb752925b..b260686315 100644 --- a/tests/test_handler_lr_scheduler.py +++ b/tests/test_handler_lr_scheduler.py @@ -10,7 +10,10 @@ # limitations under the License. import logging +import os +import re import sys +import tempfile import unittest import numpy as np @@ -24,6 +27,8 @@ class TestHandlerLrSchedule(unittest.TestCase): def test_content(self): logging.basicConfig(stream=sys.stdout, level=logging.INFO) data = [0] * 8 + test_lr = 0.1 + gamma = 0.1 # set up engine def _train_func(engine, batch): @@ -41,24 +46,45 @@ def run_validation(engine): net = torch.nn.PReLU() def _reduce_lr_on_plateau(): - optimizer = torch.optim.SGD(net.parameters(), 0.1) + optimizer = torch.optim.SGD(net.parameters(), test_lr) lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1) handler = LrScheduleHandler(lr_scheduler, step_transform=lambda x: val_engine.state.metrics["val_loss"]) handler.attach(train_engine) - return lr_scheduler + return handler - def _reduce_on_step(): - optimizer = torch.optim.SGD(net.parameters(), 0.1) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1) - handler = LrScheduleHandler(lr_scheduler) - handler.attach(train_engine) - return lr_scheduler + with tempfile.TemporaryDirectory() as tempdir: + key_to_handler = "test_log_lr" + key_to_print = "Current learning rate" + filename = os.path.join(tempdir, "test_lr.log") + # test with additional logging handler + file_saver = logging.FileHandler(filename, mode="w") + file_saver.setLevel(logging.INFO) + + def _reduce_on_step(): + optimizer = torch.optim.SGD(net.parameters(), test_lr) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=gamma) + handler = LrScheduleHandler(lr_scheduler, name=key_to_handler, logger_handler=file_saver) + handler.attach(train_engine) + handler.logger.setLevel(logging.INFO) + return handler + + schedulers = _reduce_lr_on_plateau(), _reduce_on_step() + + train_engine.run(data, max_epochs=5) + file_saver.close() + schedulers[1].logger.removeHandler(file_saver) - schedulers = _reduce_lr_on_plateau(), _reduce_on_step() + with open(filename) as f: + output_str = f.read() + has_key_word = re.compile(f".*{key_to_print}.*") + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) - train_engine.run(data, max_epochs=5) for scheduler in schedulers: - np.testing.assert_allclose(scheduler._last_lr[0], 0.001) + np.testing.assert_allclose(scheduler.lr_scheduler._last_lr[0], 0.001) if __name__ == "__main__": diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index aa6bc427e1..ee0df74002 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -45,18 +45,19 @@ def _update_metric(engine): # set up testing handler stats_handler = StatsHandler(name=key_to_handler, logger_handler=log_handler) stats_handler.attach(engine) + stats_handler.logger.setLevel(logging.INFO) engine.run(range(3), max_epochs=2) # check logging output output_str = log_stream.getvalue() log_handler.close() - grep = re.compile(f".*{key_to_handler}.*") has_key_word = re.compile(f".*{key_to_print}.*") - for idx, line in enumerate(output_str.split("\n")): - if grep.match(line): - if idx in [5, 10]: - self.assertTrue(has_key_word.match(line)) + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) def test_loss_print(self): log_stream = StringIO() @@ -74,18 +75,19 @@ def _train_func(engine, batch): # set up testing handler stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=log_handler) stats_handler.attach(engine) + stats_handler.logger.setLevel(logging.INFO) engine.run(range(3), max_epochs=2) # check logging output output_str = log_stream.getvalue() log_handler.close() - grep = re.compile(f".*{key_to_handler}.*") has_key_word = re.compile(f".*{key_to_print}.*") - for idx, line in enumerate(output_str.split("\n")): - if grep.match(line): - if idx in [1, 2, 3, 6, 7, 8]: - self.assertTrue(has_key_word.match(line)) + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) def test_loss_dict(self): log_stream = StringIO() @@ -102,21 +104,22 @@ def _train_func(engine, batch): # set up testing handler stats_handler = StatsHandler( - name=key_to_handler, output_transform=lambda x: {key_to_print: x}, logger_handler=log_handler + name=key_to_handler, output_transform=lambda x: {key_to_print: x[0]}, logger_handler=log_handler ) stats_handler.attach(engine) + stats_handler.logger.setLevel(logging.INFO) engine.run(range(3), max_epochs=2) # check logging output output_str = log_stream.getvalue() log_handler.close() - grep = re.compile(f".*{key_to_handler}.*") has_key_word = re.compile(f".*{key_to_print}.*") - for idx, line in enumerate(output_str.split("\n")): - if grep.match(line): - if idx in [1, 2, 3, 6, 7, 8]: - self.assertTrue(has_key_word.match(line)) + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) def test_loss_file(self): key_to_handler = "test_logging" @@ -136,18 +139,19 @@ def _train_func(engine, batch): # set up testing handler stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=handler) stats_handler.attach(engine) + stats_handler.logger.setLevel(logging.INFO) engine.run(range(3), max_epochs=2) handler.close() stats_handler.logger.removeHandler(handler) with open(filename) as f: output_str = f.read() - grep = re.compile(f".*{key_to_handler}.*") has_key_word = re.compile(f".*{key_to_print}.*") - for idx, line in enumerate(output_str.split("\n")): - if grep.match(line): - if idx in [1, 2, 3, 6, 7, 8]: - self.assertTrue(has_key_word.match(line)) + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) def test_exception(self): # set up engine @@ -190,17 +194,19 @@ def _update_metric(engine): name=key_to_handler, state_attributes=["test1", "test2", "test3"], logger_handler=log_handler ) stats_handler.attach(engine) + stats_handler.logger.setLevel(logging.INFO) engine.run(range(3), max_epochs=2) # check logging output output_str = log_stream.getvalue() log_handler.close() - grep = re.compile(f".*{key_to_handler}.*") has_key_word = re.compile(".*State values.*") - for idx, line in enumerate(output_str.split("\n")): - if grep.match(line) and idx in [5, 10]: - self.assertTrue(has_key_word.match(line)) + content_count = 0 + for line in output_str.split("\n"): + if has_key_word.match(line): + content_count += 1 + self.assertTrue(content_count > 0) if __name__ == "__main__":