diff --git a/monai/handlers/garbage_collector.py b/monai/handlers/garbage_collector.py index 858c78095a..f02e48364b 100644 --- a/monai/handlers/garbage_collector.py +++ b/monai/handlers/garbage_collector.py @@ -19,7 +19,11 @@ if TYPE_CHECKING: from ignite.engine import Engine, Events + from ignite.engine.events import CallableEventWithFilter else: + CallableEventWithFilter, _ = optional_import( + "ignite.engine.events", IgniteInfo.OPT_IMPORT_VERSION, min_version, "CallableEventWithFilter" + ) Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") @@ -43,9 +47,9 @@ class GarbageCollector: - 0 (NOTSET) """ - def __init__(self, trigger_event: str = "epoch", log_level: int = 10): - self.trigger_event: Events - if isinstance(trigger_event, Events): + def __init__(self, trigger_event: str | Events | CallableEventWithFilter = "epoch", log_level: int = 10): + self.trigger_event: Events | CallableEventWithFilter + if isinstance(trigger_event, Events) or isinstance(trigger_event, CallableEventWithFilter): self.trigger_event = trigger_event elif trigger_event.lower() == "epoch": self.trigger_event = Events.EPOCH_COMPLETED