diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 080e7e138c..869467c496 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -115,3 +115,8 @@ EarlyStop handler ----------------- .. autoclass:: EarlyStopHandler :members: + +GarbageCollector handler +------------------------ +.. autoclass:: GarbageCollector + :members: diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index a1f86310ae..2112b074a0 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -14,6 +14,7 @@ from .classification_saver import ClassificationSaver from .confusion_matrix import ConfusionMatrix from .earlystop_handler import EarlyStopHandler +from .garbage_collector import GarbageCollector from .hausdorff_distance import HausdorffDistance from .iteration_metric import IterationMetric from .lr_schedule_handler import LrScheduleHandler diff --git a/monai/handlers/garbage_collector.py b/monai/handlers/garbage_collector.py new file mode 100644 index 0000000000..7bb59c9049 --- /dev/null +++ b/monai/handlers/garbage_collector.py @@ -0,0 +1,80 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +from typing import TYPE_CHECKING + +from monai.utils import exact_version, optional_import + +if TYPE_CHECKING: + from ignite.engine import Engine, Events +else: + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") + Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") + + +class GarbageCollector: + """ + Run garbage collector after each epoch + + Args: + trigger_event: the event that trigger a call to this handler. + - "epoch", after completion of each epoch (equivalent of ignite.engine.Events.EPOCH_COMPLETED) + - "iteration", after completion of each iteration (equivalent of ignite.engine.Events.ITERATION_COMPLETED) + - any ignite built-in event from ignite.engine.Events. + Defaults to "epoch". + log_level: log level (integer) for some garbage collection information as below. Defaults to 10 (DEBUG). + - 50 (CRITICAL) + - 40 (ERROR) + - 30 (WARNING) + - 20 (INFO) + - 10 (DEBUG) + - 0 (NOTSET) + """ + + def __init__(self, trigger_event: str = "epoch", log_level: int = 10): + if isinstance(trigger_event, Events): + self.trigger_event = trigger_event + elif trigger_event.lower() == "epoch": + self.trigger_event = Events.EPOCH_COMPLETED + elif trigger_event.lower() == "iteration": + self.trigger_event = Events.ITERATION_COMPLETED + else: + raise ValueError( + f"'trigger_event' should be either epoch, iteration, or an ignite built-in event from" + f" ignite.engine.Events, '{trigger_event}' was given." + ) + + self.log_level = log_level + + def attach(self, engine: Engine) -> None: + if not engine.has_event_handler(self, self.trigger_event): + engine.add_event_handler(self.trigger_event, self) + + def __call__(self, engine: Engine) -> None: + """ + This method calls python garbage collector. + + Args: + engine: Ignite Engine, it should be either a trainer or validator. + """ + # get count before garbage collection + pre_count = gc.get_count() + # fits call to garbage collector + gc.collect() + # second call to garbage collector + unreachable = gc.collect() + # get count after garbage collection + after_count = gc.get_count() + engine.logger.log( + self.log_level, + f"Garbage Count: [before: {pre_count}] -> [after: {after_count}] (unreachable : {unreachable})", + ) diff --git a/tests/min_tests.py b/tests/min_tests.py index 4433081c46..9b96e7eaab 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -42,6 +42,7 @@ def run_testsuit(): "test_handler_confusion_matrix", "test_handler_confusion_matrix_dist", "test_handler_hausdorff_distance", + "test_handler_garbage_collector", "test_handler_mean_dice", "test_handler_prob_map_producer", "test_handler_rocauc", diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py new file mode 100644 index 0000000000..5e6bd7275c --- /dev/null +++ b/tests/test_handler_garbage_collector.py @@ -0,0 +1,77 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest +from unittest import skipUnless + +import torch +from ignite.engine import Engine +from parameterized import parameterized + +from monai.data import Dataset +from monai.handlers import GarbageCollector +from monai.utils import exact_version, optional_import + +Events, has_ignite = optional_import("ignite.engine", "0.4.4", exact_version, "Events") + + +TEST_CASE_0 = [[0, 1, 2], "epoch"] + +TEST_CASE_1 = [[0, 1, 2], "iteration"] + +TEST_CASE_2 = [[0, 1, 2], Events.EPOCH_COMPLETED] + + +class TestHandlerGarbageCollector(unittest.TestCase): + @skipUnless(has_ignite, "Requires ignite") + @parameterized.expand( + [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + ] + ) + def test_content(self, data, trigger_event): + # set up engine + gb_count_dict = {} + + def _train_func(engine, batch): + # store garbage collection counts + if trigger_event == Events.EPOCH_COMPLETED or trigger_event.lower() == "epoch": + if engine.state.iteration % engine.state.epoch_length == 1: + gb_count_dict[engine.state.epoch] = gc.get_count() + elif trigger_event.lower() == "iteration": + gb_count_dict[engine.state.iteration] = gc.get_count() + + engine = Engine(_train_func) + + # set up testing handler + dataset = Dataset(data, transform=None) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=1) + GarbageCollector(trigger_event=trigger_event, log_level=30).attach(engine) + + engine.run(data_loader, max_epochs=5) + print(gb_count_dict) + + first_count = 0 + for epoch, gb_count in gb_count_dict.items(): + # At least one zero-generation object + self.assertGreater(gb_count[0], 0) + if epoch == 1: + first_count = gb_count[0] + else: + # The should be less number of collected objects in the next calls. + self.assertLess(gb_count[0], first_count) + + +if __name__ == "__main__": + unittest.main()