From dfaa6a01412d860f6bfed707fc2e56639228edcb Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 23 Oct 2023 15:44:41 +0100 Subject: [PATCH 01/13] Skeletion factory class --- monai/networks/layers/factories.py | 4 +- monai/utils/__init__.py | 2 + monai/utils/component_store.py | 118 +++++++++++++++++++++++++++++ monai/utils/factory.py | 102 +++++++++++++++++++++++++ tests/test_component_store.py | 72 ++++++++++++++++++ 5 files changed, 296 insertions(+), 2 deletions(-) create mode 100644 monai/utils/component_store.py create mode 100644 monai/utils/factory.py create mode 100644 tests/test_component_store.py diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index bb56b0c0c5..78d23761b2 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -68,12 +68,12 @@ def use_factory(fact_args): import torch.nn as nn from monai.networks.utils import has_nvfuser_instance_norm -from monai.utils import look_up_option, optional_import +from monai.utils import Factory, look_up_option, optional_import __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] -class LayerFactory: +class LayerFactory(Factory): """ Factory object for creating layers, this uses given factory functions to actually produce the types or constructing callables. These functions are referred to by name and can be added at any time. diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index c973d4bfa1..a558c38ef0 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -13,6 +13,7 @@ # have to explicitly bring these in here to resolve circular import issues from .aliases import alias, resolve_name +from .component_store import ComponentStore from .decorators import MethodReplacer, RestartGenerator from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather @@ -61,6 +62,7 @@ Weight, WSIPatchKeys, ) +from .factory import Factory from .jupyter_utils import StatusMembers, ThreadContainer from .misc import ( MAX_SEED, diff --git a/monai/utils/component_store.py b/monai/utils/component_store.py new file mode 100644 index 0000000000..67ded2c321 --- /dev/null +++ b/monai/utils/component_store.py @@ -0,0 +1,118 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import namedtuple +from keyword import iskeyword +from textwrap import dedent, indent +from typing import Any, Callable, Iterable, TypeVar + +T = TypeVar("T") +from monai.utils.factory import Factory + + +def is_variable(name): + """Returns True if `name` is a valid Python variable name and also not a keyword.""" + return name.isidentifier() and not iskeyword(name) + + +class ComponentStore(Factory): + """ + Represents a storage object for other objects (specifically functions) keyed to a name with a description. + + These objects act as global named places for storing components for objects parameterised by component names. + Typically this is functions although other objects can be added. Printing a component store will produce a + list of members along with their docstring information if present. + + Example: + + .. code-block:: python + + TestStore = ComponentStore("Test Store", "A test store for demo purposes") + + @TestStore.add_def("my_func_name", "Some description of your function") + def _my_func(a, b): + '''A description of your function here.''' + return a * b + + print(TestStore) # will print out name, description, and 'my_func_name' with the docstring + + func = TestStore["my_func_name"] + result = func(7, 6) + + """ + + _Component = namedtuple("Component", ("description", "value")) # internal value pair + + def __init__(self, name: str, description: str) -> None: + self.components: dict[str, self._Component] = {} + self.name: str = name + self.description: str = description + + self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip() + + def add(self, name: str, desc: str, value: T) -> T: + """Store the object `value` under the name `name` with description `desc`.""" + if not is_variable(name): + raise ValueError("Name of component must be valid Python identifier") + + self.components[name] = self._Component(desc, value) + return value + + def add_def(self, name: str, desc: str) -> Callable: + """Returns a decorator which stores the decorated function under `name` with description `desc`.""" + + def deco(func): + """Decorator to add a function to a store.""" + return self.add(name, desc, func) + + return deco + + def __contains__(self, name: str) -> bool: + """Returns True if the given name is stored.""" + return name in self.components + + def __len__(self) -> int: + """Returns the number of stored components.""" + return len(self.components) + + def __iter__(self) -> Iterable: + """Yields name/component pairs.""" + for k, v in self.components.items(): + yield k, v.value + + def __str__(self): + result = f"Component Store '{self.name}': {self.description}\nAvailable components:" + for k, v in self.components.items(): + result += f"\n* {k}:" + + if hasattr(v.value, "__doc__"): + doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ") + result += f"\n{doc}\n" + else: + result += f" {v.description}" + + return result + + def __getattr__(self, name: str) -> Any: + """Returns the stored object under the given name.""" + if name in self.components: + return self.components[name].value + else: + return self.__getattribute__(name) + + def __getitem__(self, name: str) -> Any: + """Returns the stored object under the given name.""" + if name in self.components: + return self.components[name].value + else: + raise ValueError(f"Component '{name}' not found") diff --git a/monai/utils/factory.py b/monai/utils/factory.py new file mode 100644 index 0000000000..97b891ca30 --- /dev/null +++ b/monai/utils/factory.py @@ -0,0 +1,102 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Defines a generic factory class. +""" + +from __future__ import annotations + + +class Factory: + """ + Baseline factory object. + """ + + # def __init__(self) -> None: + # self.factories: dict[str, Callable] = {} + # + # @property + # def names(self) -> tuple[str, ...]: + # """ + # Produces all factory names. + # """ + # + # return tuple(self.factories) + # + # def add_factory_callable(self, name: str, func: Callable) -> None: + # """ + # Add the factory function to this object under the given name. + # """ + # + # self.factories[name.upper()] = func + # self.__doc__ = ( + # "The supported member" + # + ("s are: " if len(self.names) > 1 else " is: ") + # + ", ".join(f"``{name}``" for name in self.names) + # + ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing." + # ) + # + # def factory_function(self, name: str) -> Callable: + # """ + # Decorator for adding a factory function with the given name. + # """ + # + # def _add(func: Callable) -> Callable: + # self.add_factory_callable(name, func) + # return func + # + # return _add + # + # def get_constructor(self, factory_name: str, *args) -> Any: + # """ + # Get the constructor for the given factory name and arguments. + # + # Raises: + # TypeError: When ``factory_name`` is not a ``str``. + # + # """ + # + # if not isinstance(factory_name, str): + # raise TypeError(f"factory_name must a str but is {type(factory_name).__name__}.") + # + # func = look_up_option(factory_name.upper(), self.factories) + # return func(*args) + # + # def __getitem__(self, args) -> Any: + # """ + # Get the given name or name/arguments pair. If `args` is a callable it is assumed to be the constructor + # itself and is returned, otherwise it should be the factory name or a pair containing the name and arguments. + # """ + # + # # `args[0]` is actually a type or constructor + # if callable(args): + # return args + # + # # `args` is a factory name or a name with arguments + # if isinstance(args, str): + # name_obj, args = args, () + # else: + # name_obj, *args = args + # + # return self.get_constructor(name_obj, *args) + # + # def __getattr__(self, key): + # """ + # If `key` is a factory name, return it, otherwise behave as inherited. This allows referring to factory names + # as if they were constants, eg. `Fact.FOO` for a factory Fact with factory function foo. + # """ + # + # if key in self.factories: + # return key + # + # return super().__getattribute__(key) + # + # diff --git a/tests/test_component_store.py b/tests/test_component_store.py new file mode 100644 index 0000000000..614f387754 --- /dev/null +++ b/tests/test_component_store.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from monai.utils import ComponentStore + + +class TestComponentStore(unittest.TestCase): + def setUp(self): + self.cs = ComponentStore("TestStore", "I am a test store, please ignore") + + def test_empty(self): + self.assertEqual(len(self.cs), 0) + self.assertEqual(list(self.cs), []) + + def test_add(self): + test_obj = object() + + self.assertFalse("test_obj" in self.cs) + + self.cs.add("test_obj", "Test object", test_obj) + + self.assertTrue("test_obj" in self.cs) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(list(self.cs), [("test_obj", test_obj)]) + + self.assertEqual(self.cs.test_obj, test_obj) + self.assertEqual(self.cs["test_obj"], test_obj) + + def test_add2(self): + test_obj1 = object() + test_obj2 = object() + + self.cs.add("test_obj1", "Test object", test_obj1) + self.cs.add("test_obj2", "Test object", test_obj2) + + self.assertEqual(len(self.cs), 2) + self.assertTrue("test_obj1" in self.cs) + self.assertTrue("test_obj2" in self.cs) + + def test_add_def(self): + self.assertFalse("test_func" in self.cs) + + @self.cs.add_def("test_func", "Test function") + def test_func(): + return 123 + + self.assertTrue("test_func" in self.cs) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(list(self.cs), [("test_func", test_func)]) + + self.assertEqual(self.cs.test_func, test_func) + self.assertEqual(self.cs["test_func"], test_func) + + # try adding the same function again + self.cs.add_def("test_func", "Test function but with new description")(test_func) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(self.cs.test_func, test_func) From 99ecb5a4890b9e951d1ad495897e6f408a29ab83 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 24 Oct 2023 10:40:06 +0100 Subject: [PATCH 02/13] Renaming and moving some functionality into base Factory class --- monai/networks/layers/factories.py | 94 ++++++++++-------------- monai/utils/component_store.py | 39 +++------- monai/utils/factory.py | 114 +++++++++-------------------- tests/test_component_store.py | 4 +- 4 files changed, 88 insertions(+), 163 deletions(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 78d23761b2..2917d18858 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -63,7 +63,7 @@ def use_factory(fact_args): import warnings from collections.abc import Callable -from typing import Any +from typing import Any, Iterable import torch.nn as nn @@ -82,15 +82,7 @@ class LayerFactory(Factory): def __init__(self) -> None: self.factories: dict[str, Callable] = {} - @property - def names(self) -> tuple[str, ...]: - """ - Produces all factory names. - """ - - return tuple(self.factories) - - def add_factory_callable(self, name: str, func: Callable) -> None: + def add(self, name: str, func: Callable) -> None: """ Add the factory function to this object under the given name. """ @@ -103,17 +95,6 @@ def add_factory_callable(self, name: str, func: Callable) -> None: + ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing." ) - def factory_function(self, name: str) -> Callable: - """ - Decorator for adding a factory function with the given name. - """ - - def _add(func: Callable) -> Callable: - self.add_factory_callable(name, func) - return func - - return _add - def get_constructor(self, factory_name: str, *args) -> Any: """ Get the constructor for the given factory name and arguments. @@ -158,6 +139,11 @@ def __getattr__(self, key): return super().__getattribute__(key) + def __iter__(self) -> Iterable: + """Yields name/component pairs.""" + for k, v in self.factories.items(): + yield k, v + def split_args(args): """ @@ -203,50 +189,50 @@ def split_args(args): Pad = LayerFactory() -@Dropout.factory_function("dropout") +@Dropout.factory_item("dropout") def dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]: types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d) return types[dim - 1] -@Dropout.factory_function("alphadropout") +@Dropout.factory_item("alphadropout") def alpha_dropout_factory(_dim): return nn.AlphaDropout -@Norm.factory_function("instance") +@Norm.factory_item("instance") def instance_factory(dim: int) -> type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]: types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d) return types[dim - 1] -@Norm.factory_function("batch") +@Norm.factory_item("batch") def batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]: types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) return types[dim - 1] -@Norm.factory_function("group") +@Norm.factory_item("group") def group_factory(_dim) -> type[nn.GroupNorm]: return nn.GroupNorm -@Norm.factory_function("layer") +@Norm.factory_item("layer") def layer_factory(_dim) -> type[nn.LayerNorm]: return nn.LayerNorm -@Norm.factory_function("localresponse") +@Norm.factory_item("localresponse") def local_response_factory(_dim) -> type[nn.LocalResponseNorm]: return nn.LocalResponseNorm -@Norm.factory_function("syncbatch") +@Norm.factory_item("syncbatch") def sync_batch_factory(_dim) -> type[nn.SyncBatchNorm]: return nn.SyncBatchNorm -@Norm.factory_function("instance_nvfuser") +@Norm.factory_item("instance_nvfuser") def instance_nvfuser_factory(dim): """ `InstanceNorm3dNVFuser` is a faster version of InstanceNorm layer and implemented in `apex`. @@ -274,91 +260,91 @@ def instance_nvfuser_factory(dim): return optional_import("apex.normalization", name="InstanceNorm3dNVFuser")[0] -Act.add_factory_callable("elu", lambda: nn.modules.ELU) -Act.add_factory_callable("relu", lambda: nn.modules.ReLU) -Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU) -Act.add_factory_callable("prelu", lambda: nn.modules.PReLU) -Act.add_factory_callable("relu6", lambda: nn.modules.ReLU6) -Act.add_factory_callable("selu", lambda: nn.modules.SELU) -Act.add_factory_callable("celu", lambda: nn.modules.CELU) -Act.add_factory_callable("gelu", lambda: nn.modules.GELU) -Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid) -Act.add_factory_callable("tanh", lambda: nn.modules.Tanh) -Act.add_factory_callable("softmax", lambda: nn.modules.Softmax) -Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax) +Act.add("elu", lambda: nn.modules.ELU) +Act.add("relu", lambda: nn.modules.ReLU) +Act.add("leakyrelu", lambda: nn.modules.LeakyReLU) +Act.add("prelu", lambda: nn.modules.PReLU) +Act.add("relu6", lambda: nn.modules.ReLU6) +Act.add("selu", lambda: nn.modules.SELU) +Act.add("celu", lambda: nn.modules.CELU) +Act.add("gelu", lambda: nn.modules.GELU) +Act.add("sigmoid", lambda: nn.modules.Sigmoid) +Act.add("tanh", lambda: nn.modules.Tanh) +Act.add("softmax", lambda: nn.modules.Softmax) +Act.add("logsoftmax", lambda: nn.modules.LogSoftmax) -@Act.factory_function("swish") +@Act.factory_item("swish") def swish_factory(): from monai.networks.blocks.activation import Swish return Swish -@Act.factory_function("memswish") +@Act.factory_item("memswish") def memswish_factory(): from monai.networks.blocks.activation import MemoryEfficientSwish return MemoryEfficientSwish -@Act.factory_function("mish") +@Act.factory_item("mish") def mish_factory(): from monai.networks.blocks.activation import Mish return Mish -@Act.factory_function("geglu") +@Act.factory_item("geglu") def geglu_factory(): from monai.networks.blocks.activation import GEGLU return GEGLU -@Conv.factory_function("conv") +@Conv.factory_item("conv") def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]: types = (nn.Conv1d, nn.Conv2d, nn.Conv3d) return types[dim - 1] -@Conv.factory_function("convtrans") +@Conv.factory_item("convtrans") def convtrans_factory(dim: int) -> type[nn.ConvTranspose1d | nn.ConvTranspose2d | nn.ConvTranspose3d]: types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d) return types[dim - 1] -@Pool.factory_function("max") +@Pool.factory_item("max") def maxpooling_factory(dim: int) -> type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d]: types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d) return types[dim - 1] -@Pool.factory_function("adaptivemax") +@Pool.factory_item("adaptivemax") def adaptive_maxpooling_factory(dim: int) -> type[nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d]: types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d) return types[dim - 1] -@Pool.factory_function("avg") +@Pool.factory_item("avg") def avgpooling_factory(dim: int) -> type[nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d]: types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d) return types[dim - 1] -@Pool.factory_function("adaptiveavg") +@Pool.factory_item("adaptiveavg") def adaptive_avgpooling_factory(dim: int) -> type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d]: types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d) return types[dim - 1] -@Pad.factory_function("replicationpad") +@Pad.factory_item("replicationpad") def replication_pad_factory(dim: int) -> type[nn.ReplicationPad1d | nn.ReplicationPad2d | nn.ReplicationPad3d]: types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d) return types[dim - 1] -@Pad.factory_function("constantpad") +@Pad.factory_item("constantpad") def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]: types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d) return types[dim - 1] diff --git a/monai/utils/component_store.py b/monai/utils/component_store.py index 67ded2c321..c60d003cf7 100644 --- a/monai/utils/component_store.py +++ b/monai/utils/component_store.py @@ -14,11 +14,12 @@ from collections import namedtuple from keyword import iskeyword from textwrap import dedent, indent -from typing import Any, Callable, Iterable, TypeVar +from typing import Any, Iterable, TypeVar -T = TypeVar("T") from monai.utils.factory import Factory +T = TypeVar("T") + def is_variable(name): """Returns True if `name` is a valid Python variable name and also not a keyword.""" @@ -54,7 +55,7 @@ def _my_func(a, b): _Component = namedtuple("Component", ("description", "value")) # internal value pair def __init__(self, name: str, description: str) -> None: - self.components: dict[str, self._Component] = {} + self.factories: dict[str, self._Component] = {} self.name: str = name self.description: str = description @@ -62,37 +63,21 @@ def __init__(self, name: str, description: str) -> None: def add(self, name: str, desc: str, value: T) -> T: """Store the object `value` under the name `name` with description `desc`.""" + if not is_variable(name): raise ValueError("Name of component must be valid Python identifier") - self.components[name] = self._Component(desc, value) + self.factories[name] = self._Component(desc, value) return value - def add_def(self, name: str, desc: str) -> Callable: - """Returns a decorator which stores the decorated function under `name` with description `desc`.""" - - def deco(func): - """Decorator to add a function to a store.""" - return self.add(name, desc, func) - - return deco - - def __contains__(self, name: str) -> bool: - """Returns True if the given name is stored.""" - return name in self.components - - def __len__(self) -> int: - """Returns the number of stored components.""" - return len(self.components) - def __iter__(self) -> Iterable: """Yields name/component pairs.""" - for k, v in self.components.items(): + for k, v in self.factories.items(): yield k, v.value def __str__(self): result = f"Component Store '{self.name}': {self.description}\nAvailable components:" - for k, v in self.components.items(): + for k, v in self.factories.items(): result += f"\n* {k}:" if hasattr(v.value, "__doc__"): @@ -105,14 +90,14 @@ def __str__(self): def __getattr__(self, name: str) -> Any: """Returns the stored object under the given name.""" - if name in self.components: - return self.components[name].value + if name in self.factories: + return self.factories[name].value else: return self.__getattribute__(name) def __getitem__(self, name: str) -> Any: """Returns the stored object under the given name.""" - if name in self.components: - return self.components[name].value + if name in self.factories: + return self.factories[name].value else: raise ValueError(f"Component '{name}' not found") diff --git a/monai/utils/factory.py b/monai/utils/factory.py index 97b891ca30..0d1cafde07 100644 --- a/monai/utils/factory.py +++ b/monai/utils/factory.py @@ -14,89 +14,43 @@ from __future__ import annotations +from typing import Callable + class Factory: """ Baseline factory object. """ - # def __init__(self) -> None: - # self.factories: dict[str, Callable] = {} - # - # @property - # def names(self) -> tuple[str, ...]: - # """ - # Produces all factory names. - # """ - # - # return tuple(self.factories) - # - # def add_factory_callable(self, name: str, func: Callable) -> None: - # """ - # Add the factory function to this object under the given name. - # """ - # - # self.factories[name.upper()] = func - # self.__doc__ = ( - # "The supported member" - # + ("s are: " if len(self.names) > 1 else " is: ") - # + ", ".join(f"``{name}``" for name in self.names) - # + ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing." - # ) - # - # def factory_function(self, name: str) -> Callable: - # """ - # Decorator for adding a factory function with the given name. - # """ - # - # def _add(func: Callable) -> Callable: - # self.add_factory_callable(name, func) - # return func - # - # return _add - # - # def get_constructor(self, factory_name: str, *args) -> Any: - # """ - # Get the constructor for the given factory name and arguments. - # - # Raises: - # TypeError: When ``factory_name`` is not a ``str``. - # - # """ - # - # if not isinstance(factory_name, str): - # raise TypeError(f"factory_name must a str but is {type(factory_name).__name__}.") - # - # func = look_up_option(factory_name.upper(), self.factories) - # return func(*args) - # - # def __getitem__(self, args) -> Any: - # """ - # Get the given name or name/arguments pair. If `args` is a callable it is assumed to be the constructor - # itself and is returned, otherwise it should be the factory name or a pair containing the name and arguments. - # """ - # - # # `args[0]` is actually a type or constructor - # if callable(args): - # return args - # - # # `args` is a factory name or a name with arguments - # if isinstance(args, str): - # name_obj, args = args, () - # else: - # name_obj, *args = args - # - # return self.get_constructor(name_obj, *args) - # - # def __getattr__(self, key): - # """ - # If `key` is a factory name, return it, otherwise behave as inherited. This allows referring to factory names - # as if they were constants, eg. `Fact.FOO` for a factory Fact with factory function foo. - # """ - # - # if key in self.factories: - # return key - # - # return super().__getattribute__(key) - # - # + def __len__(self) -> int: + """Returns the number of stored components.""" + return len(self.factories) + + def __contains__(self, name: str) -> bool: + """Returns True if the given name is stored.""" + return name in self.factories + + @property + def names(self) -> tuple[str, ...]: + """ + Produces all factory names. + """ + + return tuple(self.factories) + + def add(self, *args) -> None: + """ + Add a factory item. + """ + raise NotImplementedError + + def factory_item(self, *args) -> Callable: + """ + Decorator for adding a factory item with the given name and other associated information. + """ + + def _add(func: Callable) -> Callable: + self.add(*args, func) + return func + + return _add diff --git a/tests/test_component_store.py b/tests/test_component_store.py index 614f387754..7cd8dbf0fa 100644 --- a/tests/test_component_store.py +++ b/tests/test_component_store.py @@ -53,7 +53,7 @@ def test_add2(self): def test_add_def(self): self.assertFalse("test_func" in self.cs) - @self.cs.add_def("test_func", "Test function") + @self.cs.factory_item("test_func", "Test function") def test_func(): return 123 @@ -66,7 +66,7 @@ def test_func(): self.assertEqual(self.cs["test_func"], test_func) # try adding the same function again - self.cs.add_def("test_func", "Test function but with new description")(test_func) + self.cs.factory_item("test_func", "Test function but with new description")(test_func) self.assertEqual(len(self.cs), 1) self.assertEqual(self.cs.test_func, test_func) From 664d6a4c56b3156ffd49cb3dc4a5da00e9ef7411 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 09:41:26 +0000 Subject: [PATCH 03/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/layers/factories.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 2917d18858..007ea82f24 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -141,8 +141,7 @@ def __getattr__(self, key): def __iter__(self) -> Iterable: """Yields name/component pairs.""" - for k, v in self.factories.items(): - yield k, v + yield from self.factories.items() def split_args(args): From 8ebd0b6656f14fda0222a95766109601f799b5db Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 25 Oct 2023 16:20:49 +0100 Subject: [PATCH 04/13] Revert to component store and layer factory as they were --- monai/networks/layers/factories.py | 98 +++++++++++++++++------------- monai/utils/component_store.py | 40 ++++++++---- tests/test_component_store.py | 4 +- 3 files changed, 85 insertions(+), 57 deletions(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 2917d18858..bb56b0c0c5 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -63,17 +63,17 @@ def use_factory(fact_args): import warnings from collections.abc import Callable -from typing import Any, Iterable +from typing import Any import torch.nn as nn from monai.networks.utils import has_nvfuser_instance_norm -from monai.utils import Factory, look_up_option, optional_import +from monai.utils import look_up_option, optional_import __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] -class LayerFactory(Factory): +class LayerFactory: """ Factory object for creating layers, this uses given factory functions to actually produce the types or constructing callables. These functions are referred to by name and can be added at any time. @@ -82,7 +82,15 @@ class LayerFactory(Factory): def __init__(self) -> None: self.factories: dict[str, Callable] = {} - def add(self, name: str, func: Callable) -> None: + @property + def names(self) -> tuple[str, ...]: + """ + Produces all factory names. + """ + + return tuple(self.factories) + + def add_factory_callable(self, name: str, func: Callable) -> None: """ Add the factory function to this object under the given name. """ @@ -95,6 +103,17 @@ def add(self, name: str, func: Callable) -> None: + ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing." ) + def factory_function(self, name: str) -> Callable: + """ + Decorator for adding a factory function with the given name. + """ + + def _add(func: Callable) -> Callable: + self.add_factory_callable(name, func) + return func + + return _add + def get_constructor(self, factory_name: str, *args) -> Any: """ Get the constructor for the given factory name and arguments. @@ -139,11 +158,6 @@ def __getattr__(self, key): return super().__getattribute__(key) - def __iter__(self) -> Iterable: - """Yields name/component pairs.""" - for k, v in self.factories.items(): - yield k, v - def split_args(args): """ @@ -189,50 +203,50 @@ def split_args(args): Pad = LayerFactory() -@Dropout.factory_item("dropout") +@Dropout.factory_function("dropout") def dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]: types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d) return types[dim - 1] -@Dropout.factory_item("alphadropout") +@Dropout.factory_function("alphadropout") def alpha_dropout_factory(_dim): return nn.AlphaDropout -@Norm.factory_item("instance") +@Norm.factory_function("instance") def instance_factory(dim: int) -> type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]: types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d) return types[dim - 1] -@Norm.factory_item("batch") +@Norm.factory_function("batch") def batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]: types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) return types[dim - 1] -@Norm.factory_item("group") +@Norm.factory_function("group") def group_factory(_dim) -> type[nn.GroupNorm]: return nn.GroupNorm -@Norm.factory_item("layer") +@Norm.factory_function("layer") def layer_factory(_dim) -> type[nn.LayerNorm]: return nn.LayerNorm -@Norm.factory_item("localresponse") +@Norm.factory_function("localresponse") def local_response_factory(_dim) -> type[nn.LocalResponseNorm]: return nn.LocalResponseNorm -@Norm.factory_item("syncbatch") +@Norm.factory_function("syncbatch") def sync_batch_factory(_dim) -> type[nn.SyncBatchNorm]: return nn.SyncBatchNorm -@Norm.factory_item("instance_nvfuser") +@Norm.factory_function("instance_nvfuser") def instance_nvfuser_factory(dim): """ `InstanceNorm3dNVFuser` is a faster version of InstanceNorm layer and implemented in `apex`. @@ -260,91 +274,91 @@ def instance_nvfuser_factory(dim): return optional_import("apex.normalization", name="InstanceNorm3dNVFuser")[0] -Act.add("elu", lambda: nn.modules.ELU) -Act.add("relu", lambda: nn.modules.ReLU) -Act.add("leakyrelu", lambda: nn.modules.LeakyReLU) -Act.add("prelu", lambda: nn.modules.PReLU) -Act.add("relu6", lambda: nn.modules.ReLU6) -Act.add("selu", lambda: nn.modules.SELU) -Act.add("celu", lambda: nn.modules.CELU) -Act.add("gelu", lambda: nn.modules.GELU) -Act.add("sigmoid", lambda: nn.modules.Sigmoid) -Act.add("tanh", lambda: nn.modules.Tanh) -Act.add("softmax", lambda: nn.modules.Softmax) -Act.add("logsoftmax", lambda: nn.modules.LogSoftmax) +Act.add_factory_callable("elu", lambda: nn.modules.ELU) +Act.add_factory_callable("relu", lambda: nn.modules.ReLU) +Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU) +Act.add_factory_callable("prelu", lambda: nn.modules.PReLU) +Act.add_factory_callable("relu6", lambda: nn.modules.ReLU6) +Act.add_factory_callable("selu", lambda: nn.modules.SELU) +Act.add_factory_callable("celu", lambda: nn.modules.CELU) +Act.add_factory_callable("gelu", lambda: nn.modules.GELU) +Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid) +Act.add_factory_callable("tanh", lambda: nn.modules.Tanh) +Act.add_factory_callable("softmax", lambda: nn.modules.Softmax) +Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax) -@Act.factory_item("swish") +@Act.factory_function("swish") def swish_factory(): from monai.networks.blocks.activation import Swish return Swish -@Act.factory_item("memswish") +@Act.factory_function("memswish") def memswish_factory(): from monai.networks.blocks.activation import MemoryEfficientSwish return MemoryEfficientSwish -@Act.factory_item("mish") +@Act.factory_function("mish") def mish_factory(): from monai.networks.blocks.activation import Mish return Mish -@Act.factory_item("geglu") +@Act.factory_function("geglu") def geglu_factory(): from monai.networks.blocks.activation import GEGLU return GEGLU -@Conv.factory_item("conv") +@Conv.factory_function("conv") def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]: types = (nn.Conv1d, nn.Conv2d, nn.Conv3d) return types[dim - 1] -@Conv.factory_item("convtrans") +@Conv.factory_function("convtrans") def convtrans_factory(dim: int) -> type[nn.ConvTranspose1d | nn.ConvTranspose2d | nn.ConvTranspose3d]: types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d) return types[dim - 1] -@Pool.factory_item("max") +@Pool.factory_function("max") def maxpooling_factory(dim: int) -> type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d]: types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d) return types[dim - 1] -@Pool.factory_item("adaptivemax") +@Pool.factory_function("adaptivemax") def adaptive_maxpooling_factory(dim: int) -> type[nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d]: types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d) return types[dim - 1] -@Pool.factory_item("avg") +@Pool.factory_function("avg") def avgpooling_factory(dim: int) -> type[nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d]: types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d) return types[dim - 1] -@Pool.factory_item("adaptiveavg") +@Pool.factory_function("adaptiveavg") def adaptive_avgpooling_factory(dim: int) -> type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d]: types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d) return types[dim - 1] -@Pad.factory_item("replicationpad") +@Pad.factory_function("replicationpad") def replication_pad_factory(dim: int) -> type[nn.ReplicationPad1d | nn.ReplicationPad2d | nn.ReplicationPad3d]: types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d) return types[dim - 1] -@Pad.factory_item("constantpad") +@Pad.factory_function("constantpad") def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]: types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d) return types[dim - 1] diff --git a/monai/utils/component_store.py b/monai/utils/component_store.py index c60d003cf7..6fd8e8884f 100644 --- a/monai/utils/component_store.py +++ b/monai/utils/component_store.py @@ -14,9 +14,7 @@ from collections import namedtuple from keyword import iskeyword from textwrap import dedent, indent -from typing import Any, Iterable, TypeVar - -from monai.utils.factory import Factory +from typing import Any, Callable, Iterable, TypeVar T = TypeVar("T") @@ -26,7 +24,7 @@ def is_variable(name): return name.isidentifier() and not iskeyword(name) -class ComponentStore(Factory): +class ComponentStore: """ Represents a storage object for other objects (specifically functions) keyed to a name with a description. @@ -55,7 +53,7 @@ def _my_func(a, b): _Component = namedtuple("Component", ("description", "value")) # internal value pair def __init__(self, name: str, description: str) -> None: - self.factories: dict[str, self._Component] = {} + self.components: dict[str, self._Component] = {} self.name: str = name self.description: str = description @@ -63,21 +61,37 @@ def __init__(self, name: str, description: str) -> None: def add(self, name: str, desc: str, value: T) -> T: """Store the object `value` under the name `name` with description `desc`.""" - if not is_variable(name): raise ValueError("Name of component must be valid Python identifier") - self.factories[name] = self._Component(desc, value) + self.components[name] = self._Component(desc, value) return value + def add_def(self, name: str, desc: str) -> Callable: + """Returns a decorator which stores the decorated function under `name` with description `desc`.""" + + def deco(func): + """Decorator to add a function to a store.""" + return self.add(name, desc, func) + + return deco + + def __contains__(self, name: str) -> bool: + """Returns True if the given name is stored.""" + return name in self.components + + def __len__(self) -> int: + """Returns the number of stored components.""" + return len(self.components) + def __iter__(self) -> Iterable: """Yields name/component pairs.""" - for k, v in self.factories.items(): + for k, v in self.components.items(): yield k, v.value def __str__(self): result = f"Component Store '{self.name}': {self.description}\nAvailable components:" - for k, v in self.factories.items(): + for k, v in self.components.items(): result += f"\n* {k}:" if hasattr(v.value, "__doc__"): @@ -90,14 +104,14 @@ def __str__(self): def __getattr__(self, name: str) -> Any: """Returns the stored object under the given name.""" - if name in self.factories: - return self.factories[name].value + if name in self.components: + return self.components[name].value else: return self.__getattribute__(name) def __getitem__(self, name: str) -> Any: """Returns the stored object under the given name.""" - if name in self.factories: - return self.factories[name].value + if name in self.components: + return self.components[name].value else: raise ValueError(f"Component '{name}' not found") diff --git a/tests/test_component_store.py b/tests/test_component_store.py index 7cd8dbf0fa..614f387754 100644 --- a/tests/test_component_store.py +++ b/tests/test_component_store.py @@ -53,7 +53,7 @@ def test_add2(self): def test_add_def(self): self.assertFalse("test_func" in self.cs) - @self.cs.factory_item("test_func", "Test function") + @self.cs.add_def("test_func", "Test function") def test_func(): return 123 @@ -66,7 +66,7 @@ def test_func(): self.assertEqual(self.cs["test_func"], test_func) # try adding the same function again - self.cs.factory_item("test_func", "Test function but with new description")(test_func) + self.cs.add_def("test_func", "Test function but with new description")(test_func) self.assertEqual(len(self.cs), 1) self.assertEqual(self.cs.test_func, test_func) From 81f1d884baa06f698084e51b78089404115bbe25 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 26 Oct 2023 09:57:34 +0100 Subject: [PATCH 05/13] Makes ComponentStore the base class and LayerFactory inherits from it --- monai/networks/layers/factories.py | 193 +++++++++++++++++++++++++---- monai/utils/component_store.py | 10 +- 2 files changed, 177 insertions(+), 26 deletions(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index bb56b0c0c5..69e688a062 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -68,34 +68,23 @@ def use_factory(fact_args): import torch.nn as nn from monai.networks.utils import has_nvfuser_instance_norm -from monai.utils import look_up_option, optional_import +from monai.utils import ComponentStore, look_up_option, optional_import __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] -class LayerFactory: +class LayerFactory(ComponentStore): """ Factory object for creating layers, this uses given factory functions to actually produce the types or constructing callables. These functions are referred to by name and can be added at any time. """ - def __init__(self) -> None: - self.factories: dict[str, Callable] = {} - - @property - def names(self) -> tuple[str, ...]: - """ - Produces all factory names. - """ - - return tuple(self.factories) - - def add_factory_callable(self, name: str, func: Callable) -> None: + def add_factory_callable(self, name: str, func: Callable, desc: str = None) -> None: """ Add the factory function to this object under the given name. """ - - self.factories[name.upper()] = func + self.add(name.upper(), desc or func.__doc__, func) + # self.components[name.upper()] = func self.__doc__ = ( "The supported member" + ("s are: " if len(self.names) > 1 else " is: ") @@ -126,8 +115,9 @@ def get_constructor(self, factory_name: str, *args) -> Any: if not isinstance(factory_name, str): raise TypeError(f"factory_name must a str but is {type(factory_name).__name__}.") - func = look_up_option(factory_name.upper(), self.factories) - return func(*args) + component = look_up_option(factory_name.upper(), self.components) + + return component.value(*args) def __getitem__(self, args) -> Any: """ @@ -153,7 +143,7 @@ def __getattr__(self, key): as if they were constants, eg. `Fact.FOO` for a factory Fact with factory function foo. """ - if key in self.factories: + if key in self.components: return key return super().__getattribute__(key) @@ -195,54 +185,111 @@ def split_args(args): # Define factories for these layer types -Dropout = LayerFactory() -Norm = LayerFactory() -Act = LayerFactory() -Conv = LayerFactory() -Pool = LayerFactory() -Pad = LayerFactory() +Dropout = LayerFactory(name="Dropout layers", description="Factory for creating dropout layers.") +Norm = LayerFactory(name="Normalization layers", description="Factory for creating normalization layers.") +Act = LayerFactory(name="Activation layers", description="Factory for creating activation layers.") +Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.") +Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.") +Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.") @Dropout.factory_function("dropout") def dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]: + """ + Dropout layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the dropout layer + + Returns: + Dropout[dim]d + """ types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d) return types[dim - 1] @Dropout.factory_function("alphadropout") def alpha_dropout_factory(_dim): + """ + Alpha dropout layer. + + Returns: + AlphaDropout + """ return nn.AlphaDropout @Norm.factory_function("instance") def instance_factory(dim: int) -> type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]: + """ + Instance normalization layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the instance normalization layer + + Returns: + InstanceNorm[dim]d + """ types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d) return types[dim - 1] @Norm.factory_function("batch") def batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]: + """ + Batch normalization layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the batch normalization layer + + Returns: + BatchNorm[dim]d + """ types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) return types[dim - 1] @Norm.factory_function("group") def group_factory(_dim) -> type[nn.GroupNorm]: + """ + Group normalization layer. + + Returns: + GroupNorm + """ return nn.GroupNorm @Norm.factory_function("layer") def layer_factory(_dim) -> type[nn.LayerNorm]: + """ + Layer normalization layer. + + Returns: + LayerNorm + """ return nn.LayerNorm @Norm.factory_function("localresponse") def local_response_factory(_dim) -> type[nn.LocalResponseNorm]: + """ + Local response normalization layer. + + Returns: + LocalResponseNorm + """ return nn.LocalResponseNorm @Norm.factory_function("syncbatch") def sync_batch_factory(_dim) -> type[nn.SyncBatchNorm]: + """ + Synchronized batch normalization layer. + + Returns: + SyncBatchNorm + """ return nn.SyncBatchNorm @@ -290,6 +337,12 @@ def instance_nvfuser_factory(dim): @Act.factory_function("swish") def swish_factory(): + """ + Swish activation layer. + + Returns: + Swish + """ from monai.networks.blocks.activation import Swish return Swish @@ -297,6 +350,12 @@ def swish_factory(): @Act.factory_function("memswish") def memswish_factory(): + """ + Memory efficient swish activation layer. + + Returns: + MemoryEfficientSwish + """ from monai.networks.blocks.activation import MemoryEfficientSwish return MemoryEfficientSwish @@ -304,6 +363,12 @@ def memswish_factory(): @Act.factory_function("mish") def mish_factory(): + """ + Mish activation layer. + + Returns: + Mish + """ from monai.networks.blocks.activation import Mish return Mish @@ -311,6 +376,12 @@ def mish_factory(): @Act.factory_function("geglu") def geglu_factory(): + """ + GEGLU activation layer. + + Returns: + GEGLU + """ from monai.networks.blocks.activation import GEGLU return GEGLU @@ -318,47 +389,119 @@ def geglu_factory(): @Conv.factory_function("conv") def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]: + """ + Convolutional layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the convolutional layer + + Returns: + Conv[dim]d + """ types = (nn.Conv1d, nn.Conv2d, nn.Conv3d) return types[dim - 1] @Conv.factory_function("convtrans") def convtrans_factory(dim: int) -> type[nn.ConvTranspose1d | nn.ConvTranspose2d | nn.ConvTranspose3d]: + """ + Transposed convolutional layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the transposed convolutional layer + + Returns: + ConvTranspose[dim]d + """ types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d) return types[dim - 1] @Pool.factory_function("max") def maxpooling_factory(dim: int) -> type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d]: + """ + Max pooling layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the max pooling layer + + Returns: + MaxPool[dim]d + """ types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d) return types[dim - 1] @Pool.factory_function("adaptivemax") def adaptive_maxpooling_factory(dim: int) -> type[nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d]: + """ + Adaptive max pooling layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the adaptive max pooling layer + + Returns: + AdaptiveMaxPool[dim]d + """ types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d) return types[dim - 1] @Pool.factory_function("avg") def avgpooling_factory(dim: int) -> type[nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d]: + """ + Average pooling layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the average pooling layer + + Returns: + AvgPool[dim]d + """ types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d) return types[dim - 1] @Pool.factory_function("adaptiveavg") def adaptive_avgpooling_factory(dim: int) -> type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d]: + """ + Adaptive average pooling layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the adaptive average pooling layer + + Returns: + AdaptiveAvgPool[dim]d + """ types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d) return types[dim - 1] @Pad.factory_function("replicationpad") def replication_pad_factory(dim: int) -> type[nn.ReplicationPad1d | nn.ReplicationPad2d | nn.ReplicationPad3d]: + """ + Replication padding layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the replication padding layer + + Returns: + ReplicationPad[dim]d + """ types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d) return types[dim - 1] @Pad.factory_function("constantpad") def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]: + """ + Constant padding layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the constant padding layer + + Returns: + ConstantPad[dim]d + """ types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d) return types[dim - 1] diff --git a/monai/utils/component_store.py b/monai/utils/component_store.py index 6fd8e8884f..3b75a2d370 100644 --- a/monai/utils/component_store.py +++ b/monai/utils/component_store.py @@ -76,6 +76,14 @@ def deco(func): return deco + @property + def names(self) -> tuple[str, ...]: + """ + Produces all factory names. + """ + + return tuple(self.components) + def __contains__(self, name: str) -> bool: """Returns True if the given name is stored.""" return name in self.components @@ -94,7 +102,7 @@ def __str__(self): for k, v in self.components.items(): result += f"\n* {k}:" - if hasattr(v.value, "__doc__"): + if hasattr(v.value, "__doc__") and v.value.__doc__: doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ") result += f"\n{doc}\n" else: From d3669c70e300ed37bd9f0d30d90ebb9f02d40d11 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 26 Oct 2023 10:11:24 +0100 Subject: [PATCH 06/13] Adds new add_factory_class method, useful when the desired factory function only returns a single class --- monai/networks/layers/factories.py | 37 +++++++++++++++++++----------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 69e688a062..577064cc1f 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -84,7 +84,18 @@ def add_factory_callable(self, name: str, func: Callable, desc: str = None) -> N Add the factory function to this object under the given name. """ self.add(name.upper(), desc or func.__doc__, func) - # self.components[name.upper()] = func + self.__doc__ = ( + "The supported member" + + ("s are: " if len(self.names) > 1 else " is: ") + + ", ".join(f"``{name}``" for name in self.names) + + ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing." + ) + + def add_factory_class(self, name: str, cls: type, desc: str = None) -> None: + """ + Adds a factory function which returns the given class. + """ + self.add(name.upper(), desc or cls.__doc__, lambda: cls) self.__doc__ = ( "The supported member" + ("s are: " if len(self.names) > 1 else " is: ") @@ -321,18 +332,18 @@ def instance_nvfuser_factory(dim): return optional_import("apex.normalization", name="InstanceNorm3dNVFuser")[0] -Act.add_factory_callable("elu", lambda: nn.modules.ELU) -Act.add_factory_callable("relu", lambda: nn.modules.ReLU) -Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU) -Act.add_factory_callable("prelu", lambda: nn.modules.PReLU) -Act.add_factory_callable("relu6", lambda: nn.modules.ReLU6) -Act.add_factory_callable("selu", lambda: nn.modules.SELU) -Act.add_factory_callable("celu", lambda: nn.modules.CELU) -Act.add_factory_callable("gelu", lambda: nn.modules.GELU) -Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid) -Act.add_factory_callable("tanh", lambda: nn.modules.Tanh) -Act.add_factory_callable("softmax", lambda: nn.modules.Softmax) -Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax) +Act.add_factory_class("elu", nn.modules.ELU) +Act.add_factory_class("relu", nn.modules.ReLU) +Act.add_factory_class("leakyrelu", nn.modules.LeakyReLU) +Act.add_factory_class("prelu", nn.modules.PReLU) +Act.add_factory_class("relu6", nn.modules.ReLU6) +Act.add_factory_class("selu", nn.modules.SELU) +Act.add_factory_class("celu", nn.modules.CELU) +Act.add_factory_class("gelu", nn.modules.GELU) +Act.add_factory_class("sigmoid", nn.modules.Sigmoid) +Act.add_factory_class("tanh", nn.modules.Tanh) +Act.add_factory_class("softmax", nn.modules.Softmax) +Act.add_factory_class("logsoftmax", nn.modules.LogSoftmax) @Act.factory_function("swish") From 393cdc6e938ee6ce61a2b386f9b9e8ca73bc8b37 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 26 Oct 2023 10:20:42 +0100 Subject: [PATCH 07/13] Allows lambda to take an optional unused argument --- monai/networks/layers/factories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 577064cc1f..3e2690f73f 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -95,7 +95,7 @@ def add_factory_class(self, name: str, cls: type, desc: str = None) -> None: """ Adds a factory function which returns the given class. """ - self.add(name.upper(), desc or cls.__doc__, lambda: cls) + self.add(name.upper(), desc or cls.__doc__, lambda x=None: cls) self.__doc__ = ( "The supported member" + ("s are: " if len(self.names) > 1 else " is: ") From 7c8363d977e42ac3d4e470efc69678911083d0c4 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 26 Oct 2023 10:24:29 +0100 Subject: [PATCH 08/13] Removes base factory class --- monai/utils/__init__.py | 1 - monai/utils/factory.py | 56 ----------------------------------------- 2 files changed, 57 deletions(-) delete mode 100644 monai/utils/factory.py diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index a558c38ef0..82f944ccb8 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -62,7 +62,6 @@ Weight, WSIPatchKeys, ) -from .factory import Factory from .jupyter_utils import StatusMembers, ThreadContainer from .misc import ( MAX_SEED, diff --git a/monai/utils/factory.py b/monai/utils/factory.py deleted file mode 100644 index 0d1cafde07..0000000000 --- a/monai/utils/factory.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Defines a generic factory class. -""" - -from __future__ import annotations - -from typing import Callable - - -class Factory: - """ - Baseline factory object. - """ - - def __len__(self) -> int: - """Returns the number of stored components.""" - return len(self.factories) - - def __contains__(self, name: str) -> bool: - """Returns True if the given name is stored.""" - return name in self.factories - - @property - def names(self) -> tuple[str, ...]: - """ - Produces all factory names. - """ - - return tuple(self.factories) - - def add(self, *args) -> None: - """ - Add a factory item. - """ - raise NotImplementedError - - def factory_item(self, *args) -> Callable: - """ - Decorator for adding a factory item with the given name and other associated information. - """ - - def _add(func: Callable) -> Callable: - self.add(*args, func) - return func - - return _add From bb9cf2af9606de90cb1e300ee0d8cb4ab4a6ea44 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 26 Oct 2023 15:30:50 +0100 Subject: [PATCH 09/13] Further simplify some of the factory additions --- monai/networks/layers/factories.py | 65 +++++------------------------- 1 file changed, 9 insertions(+), 56 deletions(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 3e2690f73f..cd7915a9ad 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -81,7 +81,7 @@ class LayerFactory(ComponentStore): def add_factory_callable(self, name: str, func: Callable, desc: str = None) -> None: """ - Add the factory function to this object under the given name. + Add the factory function to this object under the given name, with optional description. """ self.add(name.upper(), desc or func.__doc__, func) self.__doc__ = ( @@ -93,7 +93,7 @@ def add_factory_callable(self, name: str, func: Callable, desc: str = None) -> N def add_factory_class(self, name: str, cls: type, desc: str = None) -> None: """ - Adds a factory function which returns the given class. + Adds a factory function which returns the supplied class under the given name, with optional description. """ self.add(name.upper(), desc or cls.__doc__, lambda x=None: cls) self.__doc__ = ( @@ -195,7 +195,6 @@ def split_args(args): # Define factories for these layer types - Dropout = LayerFactory(name="Dropout layers", description="Factory for creating dropout layers.") Norm = LayerFactory(name="Normalization layers", description="Factory for creating normalization layers.") Act = LayerFactory(name="Activation layers", description="Factory for creating activation layers.") @@ -219,15 +218,7 @@ def dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]: return types[dim - 1] -@Dropout.factory_function("alphadropout") -def alpha_dropout_factory(_dim): - """ - Alpha dropout layer. - - Returns: - AlphaDropout - """ - return nn.AlphaDropout +Dropout.add_factory_class("alphadropout", nn.AlphaDropout) @Norm.factory_function("instance") @@ -260,50 +251,6 @@ def batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNo return types[dim - 1] -@Norm.factory_function("group") -def group_factory(_dim) -> type[nn.GroupNorm]: - """ - Group normalization layer. - - Returns: - GroupNorm - """ - return nn.GroupNorm - - -@Norm.factory_function("layer") -def layer_factory(_dim) -> type[nn.LayerNorm]: - """ - Layer normalization layer. - - Returns: - LayerNorm - """ - return nn.LayerNorm - - -@Norm.factory_function("localresponse") -def local_response_factory(_dim) -> type[nn.LocalResponseNorm]: - """ - Local response normalization layer. - - Returns: - LocalResponseNorm - """ - return nn.LocalResponseNorm - - -@Norm.factory_function("syncbatch") -def sync_batch_factory(_dim) -> type[nn.SyncBatchNorm]: - """ - Synchronized batch normalization layer. - - Returns: - SyncBatchNorm - """ - return nn.SyncBatchNorm - - @Norm.factory_function("instance_nvfuser") def instance_nvfuser_factory(dim): """ @@ -332,6 +279,12 @@ def instance_nvfuser_factory(dim): return optional_import("apex.normalization", name="InstanceNorm3dNVFuser")[0] +Norm.add_factory_class("group", nn.GroupNorm) +Norm.add_factory_class("layer", nn.LayerNorm) +Norm.add_factory_class("localresponse", nn.LocalResponseNorm) +Norm.add_factory_class("syncbatch", nn.SyncBatchNorm) + + Act.add_factory_class("elu", nn.modules.ELU) Act.add_factory_class("relu", nn.modules.ReLU) Act.add_factory_class("leakyrelu", nn.modules.LeakyReLU) From dfd0cc89f53eb36fa2b14f8ef781556b2620baa1 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 27 Oct 2023 11:22:56 +0100 Subject: [PATCH 10/13] DCO Remediation Commit for Mark Graham I, Mark Graham , hereby add my Signed-off-by to this commit: dfaa6a01412d860f6bfed707fc2e56639228edcb I, Mark Graham , hereby add my Signed-off-by to this commit: 99ecb5a4890b9e951d1ad495897e6f408a29ab83 I, Mark Graham , hereby add my Signed-off-by to this commit: 8ebd0b6656f14fda0222a95766109601f799b5db I, Mark Graham , hereby add my Signed-off-by to this commit: 81f1d884baa06f698084e51b78089404115bbe25 I, Mark Graham , hereby add my Signed-off-by to this commit: d3669c70e300ed37bd9f0d30d90ebb9f02d40d11 I, Mark Graham , hereby add my Signed-off-by to this commit: 393cdc6e938ee6ce61a2b386f9b9e8ca73bc8b37 I, Mark Graham , hereby add my Signed-off-by to this commit: 7c8363d977e42ac3d4e470efc69678911083d0c4 I, Mark Graham , hereby add my Signed-off-by to this commit: bb9cf2af9606de90cb1e300ee0d8cb4ab4a6ea44 Signed-off-by: Mark Graham --- monai/networks/layers/factories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index cd7915a9ad..ee774739c0 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -62,7 +62,7 @@ def use_factory(fact_args): from __future__ import annotations import warnings -from collections.abc import Callable +from collections.abc import Callable from typing import Any import torch.nn as nn From afeb4172d07c191609ece3dcd2cef1541a699b32 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 27 Oct 2023 11:41:51 +0100 Subject: [PATCH 11/13] Fixes mypy errors Signed-off-by: Mark Graham --- monai/networks/layers/factories.py | 12 +++++++----- monai/utils/component_store.py | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index ee774739c0..8537d20d1d 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -62,7 +62,7 @@ def use_factory(fact_args): from __future__ import annotations import warnings -from collections.abc import Callable +from collections.abc import Callable from typing import Any import torch.nn as nn @@ -79,11 +79,12 @@ class LayerFactory(ComponentStore): callables. These functions are referred to by name and can be added at any time. """ - def add_factory_callable(self, name: str, func: Callable, desc: str = None) -> None: + def add_factory_callable(self, name: str, func: Callable, desc: str | None = None) -> None: """ Add the factory function to this object under the given name, with optional description. """ - self.add(name.upper(), desc or func.__doc__, func) + description: str = desc or func.__doc__ or "" + self.add(name.upper(), description, func) self.__doc__ = ( "The supported member" + ("s are: " if len(self.names) > 1 else " is: ") @@ -91,11 +92,12 @@ def add_factory_callable(self, name: str, func: Callable, desc: str = None) -> N + ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing." ) - def add_factory_class(self, name: str, cls: type, desc: str = None) -> None: + def add_factory_class(self, name: str, cls: type, desc: str | None = None) -> None: """ Adds a factory function which returns the supplied class under the given name, with optional description. """ - self.add(name.upper(), desc or cls.__doc__, lambda x=None: cls) + description: str = desc or cls.__doc__ or "" + self.add(name.upper(), description, lambda x=None: cls) self.__doc__ = ( "The supported member" + ("s are: " if len(self.names) > 1 else " is: ") diff --git a/monai/utils/component_store.py b/monai/utils/component_store.py index 3b75a2d370..d1e71eaebf 100644 --- a/monai/utils/component_store.py +++ b/monai/utils/component_store.py @@ -50,10 +50,10 @@ def _my_func(a, b): """ - _Component = namedtuple("Component", ("description", "value")) # internal value pair + _Component = namedtuple("_Component", ("description", "value")) # internal value pair def __init__(self, name: str, description: str) -> None: - self.components: dict[str, self._Component] = {} + self.components: dict[str, ComponentStore._Component] = {} self.name: str = name self.description: str = description From 8a51d799f87d22dff99e3cdc0280bc1599e9f06f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 27 Oct 2023 08:01:21 -0600 Subject: [PATCH 12/13] Update monai/networks/layers/factories.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/networks/layers/factories.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 8537d20d1d..e6645d6f93 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -79,7 +79,8 @@ class LayerFactory(ComponentStore): callables. These functions are referred to by name and can be added at any time. """ - def add_factory_callable(self, name: str, func: Callable, desc: str | None = None) -> None: + def add_factory_callable(self, func: Callable, name: str | None = None, desc: str | None = None) -> None: + name = name if name is not None else getattr(func, "__name__", "???") """ Add the factory function to this object under the given name, with optional description. """ From 7e6fa86d05e532eaea4d5951fb80674c076feac3 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 27 Oct 2023 15:45:40 +0100 Subject: [PATCH 13/13] Updates Signed-off-by: Mark Graham --- monai/networks/layers/factories.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index e6645d6f93..38ee68cbee 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -79,32 +79,29 @@ class LayerFactory(ComponentStore): callables. These functions are referred to by name and can be added at any time. """ - def add_factory_callable(self, func: Callable, name: str | None = None, desc: str | None = None) -> None: - name = name if name is not None else getattr(func, "__name__", "???") + def __init__(self, name: str, description: str) -> None: + super().__init__(name, description) + self.__doc__ = ( + f"Layer Factory '{name}': {description}\n".strip() + + "\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing." + + "\n\nThe supported members are:" + ) + + def add_factory_callable(self, name: str, func: Callable, desc: str | None = None) -> None: """ Add the factory function to this object under the given name, with optional description. """ description: str = desc or func.__doc__ or "" self.add(name.upper(), description, func) - self.__doc__ = ( - "The supported member" - + ("s are: " if len(self.names) > 1 else " is: ") - + ", ".join(f"``{name}``" for name in self.names) - + ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing." - ) + # append name to the docstring + assert self.__doc__ is not None + self.__doc__ += f"{', ' if len(self.names)>1 else ' '}``{name}``" def add_factory_class(self, name: str, cls: type, desc: str | None = None) -> None: """ Adds a factory function which returns the supplied class under the given name, with optional description. """ - description: str = desc or cls.__doc__ or "" - self.add(name.upper(), description, lambda x=None: cls) - self.__doc__ = ( - "The supported member" - + ("s are: " if len(self.names) > 1 else " is: ") - + ", ".join(f"``{name}``" for name in self.names) - + ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing." - ) + self.add_factory_callable(name, lambda x=None: cls, desc) def factory_function(self, name: str) -> Callable: """