diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index e35335ba0e..3d09cea545 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -104,7 +104,7 @@ def apply_transform( map_items: bool = True, unpack_items: bool = False, log_stats: bool | str = False, - lazy: bool | None = False, + lazy: bool | None = None, overrides: dict | None = None, ) -> list[ReturnType] | ReturnType: """ @@ -124,7 +124,7 @@ def apply_transform( disables the logger for processing pipeline errors. Setting it to None or True will enable logging to the default logger name. Setting it to a string specifies the logger to which errors should be logged. lazy: whether to execute in lazy mode or not. See the :ref:`Lazy Resampling topic for more - information about lazy resampling. + information about lazy resampling. Defaults to None. overrides: optional overrides to apply to transform parameters. This parameter is ignored unless transforms are being executed lazily. See the :ref:`Lazy Resampling topic for more details and examples of its usage. diff --git a/tests/test_compose.py b/tests/test_compose.py index 453ae3868d..a1952b102f 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -607,6 +607,12 @@ def test_compose_with_logger(self, keys, pipeline): "INFO - Pending transforms applied: applied_operations: 1\n" ), ], + [ + mt.OneOf, + (mt.Flip(0),), + False, + ("INFO - Apply pending transforms - lazy: False, pending: 0, " "upcoming 'Flip', transform.lazy: False\n"), + ], ] diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 667595caa4..c7c2b77697 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -11,9 +11,12 @@ from __future__ import annotations +import logging import os import tempfile import unittest +from copy import deepcopy +from io import StringIO import nibabel as nib import numpy as np @@ -21,6 +24,7 @@ from monai.data import Dataset from monai.transforms import Compose, LoadImaged, SimulateDelayd +from tests.test_compose import TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES, data_from_keys TEST_CASE_1 = [(128, 128, 128)] @@ -89,6 +93,39 @@ def test_shape(self, expected_shape): for d in data4_list: self.assertTupleEqual(d["image"].shape, expected_shape) + def test_dataset_lazy_on_call(self): + data = np.zeros((1, 5, 5)) + data[0, 0:2, 0:2] = 1 + + +class TestDatsesetWithLazy(unittest.TestCase): + LOGGER_NAME = "a_logger_name" + + def init_logger(self, name=LOGGER_NAME): + stream = StringIO() + handler = logging.StreamHandler(stream) + formatter = logging.Formatter("%(levelname)s - %(message)s") + handler.setFormatter(formatter) + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + while len(logger.handlers) > 0: + logger.removeHandler(logger.handlers[-1]) + logger.addHandler(handler) + return handler, stream + + @parameterized.expand(TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES) + def test_dataset_lazy_with_logging(self, compose_type, pipeline, lazy, expected): + handler, stream = self.init_logger(name=self.LOGGER_NAME) + + data = data_from_keys(None, 12, 16) + c = compose_type(deepcopy(pipeline), log_stats=self.LOGGER_NAME, lazy=lazy) + ds = Dataset([data], transform=c) + ds[0] + + handler.flush() + actual = stream.getvalue() + self.assertEqual(actual, expected) + if __name__ == "__main__": unittest.main()