Skip to content
5 changes: 5 additions & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,8 @@ EarlyStop handler
-----------------
.. autoclass:: EarlyStopHandler
:members:

GarbageCollector handler
------------------------
.. autoclass:: GarbageCollector
:members:
1 change: 1 addition & 0 deletions monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .classification_saver import ClassificationSaver
from .confusion_matrix import ConfusionMatrix
from .earlystop_handler import EarlyStopHandler
from .garbage_collector import GarbageCollector
from .hausdorff_distance import HausdorffDistance
from .iteration_metric import IterationMetric
from .lr_schedule_handler import LrScheduleHandler
Expand Down
80 changes: 80 additions & 0 deletions monai/handlers/garbage_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
from typing import TYPE_CHECKING

from monai.utils import exact_version, optional_import

if TYPE_CHECKING:
from ignite.engine import Engine, Events
else:
Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")


class GarbageCollector:
"""
Run garbage collector after each epoch

Args:
trigger_event: the event that trigger a call to this handler.
- "epoch", after completion of each epoch (equivalent of ignite.engine.Events.EPOCH_COMPLETED)
- "iteration", after completion of each iteration (equivalent of ignite.engine.Events.ITERATION_COMPLETED)
- any ignite built-in event from ignite.engine.Events.
Defaults to "epoch".
log_level: log level (integer) for some garbage collection information as below. Defaults to 10 (DEBUG).
- 50 (CRITICAL)
- 40 (ERROR)
- 30 (WARNING)
- 20 (INFO)
- 10 (DEBUG)
- 0 (NOTSET)
"""

def __init__(self, trigger_event: str = "epoch", log_level: int = 10):
if isinstance(trigger_event, Events):
self.trigger_event = trigger_event
elif trigger_event.lower() == "epoch":
self.trigger_event = Events.EPOCH_COMPLETED
elif trigger_event.lower() == "iteration":
self.trigger_event = Events.ITERATION_COMPLETED
else:
raise ValueError(
f"'trigger_event' should be either epoch, iteration, or an ignite built-in event from"
f" ignite.engine.Events, '{trigger_event}' was given."
)

self.log_level = log_level

def attach(self, engine: Engine) -> None:
if not engine.has_event_handler(self, self.trigger_event):
engine.add_event_handler(self.trigger_event, self)

def __call__(self, engine: Engine) -> None:
"""
This method calls python garbage collector.

Args:
engine: Ignite Engine, it should be either a trainer or validator.
"""
# get count before garbage collection
pre_count = gc.get_count()
# fits call to garbage collector
gc.collect()
# second call to garbage collector
unreachable = gc.collect()
# get count after garbage collection
after_count = gc.get_count()
engine.logger.log(
self.log_level,
f"Garbage Count: [before: {pre_count}] -> [after: {after_count}] (unreachable : {unreachable})",
)
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def run_testsuit():
"test_handler_confusion_matrix",
"test_handler_confusion_matrix_dist",
"test_handler_hausdorff_distance",
"test_handler_garbage_collector",
"test_handler_mean_dice",
"test_handler_prob_map_producer",
"test_handler_rocauc",
Expand Down
77 changes: 77 additions & 0 deletions tests/test_handler_garbage_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest
from unittest import skipUnless

import torch
from ignite.engine import Engine
from parameterized import parameterized

from monai.data import Dataset
from monai.handlers import GarbageCollector
from monai.utils import exact_version, optional_import

Events, has_ignite = optional_import("ignite.engine", "0.4.4", exact_version, "Events")


TEST_CASE_0 = [[0, 1, 2], "epoch"]

TEST_CASE_1 = [[0, 1, 2], "iteration"]

TEST_CASE_2 = [[0, 1, 2], Events.EPOCH_COMPLETED]


class TestHandlerGarbageCollector(unittest.TestCase):
@skipUnless(has_ignite, "Requires ignite")
@parameterized.expand(
[
TEST_CASE_0,
TEST_CASE_1,
TEST_CASE_2,
]
)
def test_content(self, data, trigger_event):
# set up engine
gb_count_dict = {}

def _train_func(engine, batch):
# store garbage collection counts
if trigger_event == Events.EPOCH_COMPLETED or trigger_event.lower() == "epoch":
if engine.state.iteration % engine.state.epoch_length == 1:
gb_count_dict[engine.state.epoch] = gc.get_count()
elif trigger_event.lower() == "iteration":
gb_count_dict[engine.state.iteration] = gc.get_count()

engine = Engine(_train_func)

# set up testing handler
dataset = Dataset(data, transform=None)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
GarbageCollector(trigger_event=trigger_event, log_level=30).attach(engine)

engine.run(data_loader, max_epochs=5)
print(gb_count_dict)

first_count = 0
for epoch, gb_count in gb_count_dict.items():
# At least one zero-generation object
self.assertGreater(gb_count[0], 0)
if epoch == 1:
first_count = gb_count[0]
else:
# The should be less number of collected objects in the next calls.
self.assertLess(gb_count[0], first_count)


if __name__ == "__main__":
unittest.main()