From 5ea2aa8f4f42945e3d56c317485253583ec3504f Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 30 Mar 2023 14:59:00 +0100 Subject: [PATCH 01/10] Improve Compose encapsulation (#6224) Fixes #6223 . ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Ben Murray Signed-off-by: a-parida12 --- monai/data/dataset.py | 98 ++++++++-------- monai/transforms/compose.py | 217 +++++++++++++++++++++++++++++++++--- tests/test_compose.py | 144 +++++++++++++++++++++--- 3 files changed, 378 insertions(+), 81 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 5ef8d7e903..7df19d88d3 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -39,7 +39,6 @@ Compose, Randomizable, RandomizableTrait, - ThreadUnsafe, Transform, apply_transform, convert_to_contiguous, @@ -209,6 +208,11 @@ class PersistentDataset(Dataset): not guaranteed, so caution should be used when modifying transforms to avoid unexpected errors. If in doubt, it is advisable to clear the cache directory. + Lazy Resampling: + If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to + its documentation to familiarize yourself with the interaction between `PersistentDataset` and + lazy resampling. + """ def __init__( @@ -316,15 +320,15 @@ def _pre_transform(self, item_transformed): random transform object """ - for _transform in self.transform.transforms: - # execute all the deterministic transforms - if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform): - break - # this is to be consistent with CacheDataset even though it's not in a multi-thread situation. 
- _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform - item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform) - item_transformed = apply_transform(_xform, item_transformed) - item_transformed = self.transform.evaluate_with_overrides(item_transformed, None) + if not isinstance(self.transform, Compose): + raise ValueError("transform must be an instance of monai.transforms.Compose.") + + first_random = self.transform.get_index_of_first( + lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + ) + + item_transformed = self.transform(item_transformed, end=first_random, threading=True) + if self.reset_ops_id: reset_ops_id(item_transformed) return item_transformed @@ -342,17 +346,12 @@ def _post_transform(self, item_transformed): """ if not isinstance(self.transform, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") - start_post_randomize_run = False - for _transform in self.transform.transforms: - if ( - start_post_randomize_run - or isinstance(_transform, RandomizableTrait) - or not isinstance(_transform, Transform) - ): - start_post_randomize_run = True - item_transformed = self.transform.evaluate_with_overrides(item_transformed, _transform) - item_transformed = apply_transform(_transform, item_transformed) - item_transformed = self.transform.evaluate_with_overrides(item_transformed, None) + + first_random = self.transform.get_index_of_first( + lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + ) + if first_random is not None: + item_transformed = self.transform(item_transformed, start=first_random) return item_transformed def _cachecheck(self, item_transformed): @@ -496,13 +495,9 @@ def _pre_transform(self, item_transformed): """ if not isinstance(self.transform, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") - for i, _transform in enumerate(self.transform.transforms): - if i == self.cache_n_trans: - break - _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform - item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform) - item_transformed = apply_transform(_xform, item_transformed) - item_transformed = self.transform.evaluate_with_overrides(item_transformed, None) + + item_transformed = self.transform(item_transformed, end=self.cache_n_trans, threading=True) + reset_ops_id(item_transformed) return item_transformed @@ -518,12 +513,8 @@ def _post_transform(self, item_transformed): """ if not isinstance(self.transform, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") - for i, _transform in enumerate(self.transform.transforms): - if i >= self.cache_n_trans: - item_transformed = self.transform.evaluate_with_overrides(item_transformed, item_transformed) - item_transformed = apply_transform(_transform, item_transformed) - item_transformed = self.transform.evaluate_with_overrides(item_transformed, None) - return item_transformed + + return self.transform(item_transformed, start=self.cache_n_trans) class LMDBDataset(PersistentDataset): @@ -748,6 +739,11 @@ class CacheDataset(Dataset): So to debug or verify the program before real training, users can set `cache_rate=0.0` or `cache_num=0` to temporarily skip caching. 
+ Lazy Resampling: + If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to + its documentation to familiarize yourself with the interaction between `CacheDataset` and + lazy resampling. + """ def __init__( @@ -887,14 +883,12 @@ def _load_cache_item(self, idx: int): idx: the index of the input data sequence. """ item = self.data[idx] - for _transform in self.transform.transforms: - # execute all the deterministic transforms - if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform): - break - _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform - item = self.transform.evaluate_with_overrides(item, _xform) - item = apply_transform(_xform, item) - item = self.transform.evaluate_with_overrides(item, None) + + first_random = self.transform.get_index_of_first( + lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + ) + item = self.transform(item, end=first_random, threading=True) + if self.as_contiguous: item = convert_to_contiguous(item, memory_format=torch.contiguous_format) return item @@ -921,19 +915,16 @@ def _transform(self, index: int): data = self._cache[cache_index] = self._load_cache_item(cache_index) # load data from cache and execute from the first random transform - start_run = False if not isinstance(self.transform, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") - for _transform in self.transform.transforms: - if start_run or isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform): - # only need to deep copy data on first non-deterministic transform - if not start_run: - start_run = True - if self.copy_cache: - data = deepcopy(data) - data = self.transform.evaluate_with_overrides(data, _transform) - data = apply_transform(_transform, data) - data = self.transform.evaluate_with_overrides(data, None) + + first_random = self.transform.get_index_of_first( + lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + ) + if first_random is not None: + data = deepcopy(data) if self.copy_cache is True else data + data = self.transform(data, start=first_random) + return data @@ -1008,7 +999,6 @@ class SmartCacheDataset(Randomizable, CacheDataset): as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. it may help improve the performance of following logic. runtime_cache: Default to `False`, other options are not implemented yet. 
-
     """

    def __init__(
diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py
index 6cdd1b3d55..8a8518c92b 100644
--- a/monai/transforms/compose.py
+++ b/monai/transforms/compose.py
@@ -16,6 +16,7 @@

 import warnings
 from collections.abc import Callable, Mapping, Sequence
+from copy import deepcopy
 from typing import Any

 import numpy as np
@@ -23,7 +24,9 @@
 import monai
 import monai.transforms as mt
 from monai.apps.utils import get_logger
+from monai.config import NdarrayOrTensor
 from monai.transforms.inverse import InvertibleTransform
+from monai.transforms.traits import ThreadUnsafe

 # For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)
 from monai.transforms.transform import (  # noqa: F401
@@ -115,6 +118,91 @@ def evaluate_with_overrides(
     return data


+def execute_compose(
+    data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
+    transforms: Sequence[Any],
+    map_items: bool = True,
+    unpack_items: bool = False,
+    start: int = 0,
+    end: int | None = None,
+    lazy_evaluation: bool = False,
+    overrides: dict | None = None,
+    override_keys: Sequence[str] | None = None,
+    threading: bool = False,
+    log_stats: bool = False,
+    verbose: bool = False,
+) -> NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor]:
+    """
+    ``execute_compose`` provides the implementation that the ``Compose`` class uses to execute a sequence
+    of transforms. As well as being used by Compose, it can be used by subclasses of
+    Compose and by code that doesn't have a Compose instance but needs to execute a
+    sequence of transforms as if it were executed by Compose. It should only be used directly
+    when it is not possible to use ``Compose.__call__`` to achieve the same goal.
+
+    Args:
+        data: a tensor-like object to be transformed
+        transforms: a sequence of transforms to be carried out
+        map_items: whether to apply the transform to each item in the input `data` if `data` is a list or tuple.
+            defaults to `True`.
+        unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
+            defaults to `False`.
+        start: the index of the first transform to be executed. If not set, this defaults to 0.
+        end: the index after the last transform to be executed. If set, the transform at index end-1
+            is the last transform that is executed. If this is not set, it defaults to len(transforms).
+        lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If False, transforms will be
+            carried out on a transform-by-transform basis. If True, all lazy transforms will
+            be executed by accumulating changes and resampling as few times as possible.
+            A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
+            the pending operations and make the primary data up-to-date.
+        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
+            when executing a pipeline. Each parameter in the dictionary that is compatible with a given transform
+            is applied to that transform before it is executed. Note that overrides are currently only applied when
+            lazy_evaluation is True; if lazy_evaluation is False they are ignored.
+            Currently supported args are:
+            {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``"device"``},
+            please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
+        override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied.
+            If ``overrides`` is set, ``override_keys`` must also be set.
+        threading: whether execution is happening in a threaded environment. If set, copies are made of any
+            transforms that have the ``ThreadUnsafe`` interface, so that each thread operates on its own instance.
+        log_stats: whether to log detailed information about the data and the applied transform when an error occurs;
+            for NumPy arrays and PyTorch Tensors the data shape and value range are logged,
+            for other metadata the values are logged directly. defaults to `False`.
+        verbose: whether to print debugging info when lazy_evaluation=True.
+
+    Returns:
+        A tensor-like object, a sequence of tensor-likes or a dict of tensor-likes containing the result of running
+        ``data`` through the sequence of ``transforms``.
+    """
+    end_ = len(transforms) if end is None else end
+    if start is None:
+        raise ValueError(f"'start' ({start}) cannot be None")
+    if start > end_:
+        raise ValueError(f"'start' ({start}) must not be greater than 'end' ({end_})")
+    if end_ > len(transforms):
+        raise ValueError(f"'end' ({end_}) must be less than or equal to the transform count ({len(transforms)})")
+
+    # no-op if the range is empty
+    if start == end_:
+        return data
+
+    for _transform in transforms[start:end_]:
+        if threading:
+            _transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
+        data = evaluate_with_overrides(
+            data,
+            _transform,
+            lazy_evaluation=lazy_evaluation,
+            overrides=overrides,
+            override_keys=override_keys,
+            verbose=verbose,
+        )
+        data = apply_transform(_transform, data, map_items, unpack_items, log_stats)
+    data = evaluate_with_overrides(
+        data, None, lazy_evaluation=lazy_evaluation, overrides=overrides, override_keys=override_keys, verbose=verbose
+    )
+    return data
+
+
 class Compose(Randomizable, InvertibleTransform):
     """
     ``Compose`` provides the ability to chain a series of callables together in
@@ -183,6 +271,37 @@ class Compose(Randomizable, InvertibleTransform):
     calls your pre-processing functions taking into account that not all of them
     are called on the labels.

+    Lazy resampling:
+        Lazy resampling is an experimental feature introduced in 1.2. Its purpose is
+        to reduce the number of resample operations that must be carried out when executing
+        a pipeline of transforms. This can provide significant performance improvements in
+        terms of pipeline execution speed and memory usage, and can also significantly
+        reduce the loss of information that occurs when performing a number of spatial
+        resamples in succession.
+
+        Lazy resampling can be thought of as acting in a similar fashion to the `Affine` & `RandAffine`
+        transforms, in that it allows several spatial transform operations to be specified and carried out with
+        a single resample step. Unlike these transforms, however, lazy resampling can operate on any set of
+        transforms specified in any ordering. The user is free to mix monai transforms with transforms from other
+        libraries; lazy resampling will determine the minimum number of resample steps required in order to
+        execute the pipeline.
+
+        Lazy resampling works with the monai `Dataset` classes that provide caching and persistence. However, if you
+        are implementing your own caching dataset and wish to make use of lazy resampling, you
+        should ensure that you fully execute the part of the pipeline that generates the data to be cached
+        before caching it. This is quite simple to do, however, as shown by the following example.
+
+        Example:
+            # run the part of the pipeline that needs to be cached
+            data = self.transform(data, end=self.post_cache_index)
+
+            # ---
+
+            # fetch the data from the cache and run the rest of the pipeline
+            data = get_data_from_my_cache(data)
+            data = self.transform(data, start=self.post_cache_index)
+
+
     Args:
         transforms: sequence of callables.
         map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
            defaults to `True`.
@@ -258,6 +377,41 @@ def randomize(self, data: Any | None = None) -> None:
                f'Transform "{tfm_name}" in Compose not randomized\n{tfm_name}.{type_error}.', RuntimeWarning
            )

+    def get_index_of_first(self, predicate):
+        """
+        get_index_of_first takes a ``predicate`` and returns the index of the first transform that
+        satisfies the predicate (i.e. makes the predicate return True). If it is unable to find
+        a transform that satisfies the ``predicate``, it returns None.
+
+        Example:
+            c = Compose([Flip(...), Rotate90(...), Zoom(...), RandRotate(...), Resize(...)])
+
+            print(c.get_index_of_first(lambda t: isinstance(t, RandomizableTrait)))
+            >>> 3
+            print(c.get_index_of_first(lambda t: isinstance(t, Compose)))
+            >>> None
+
+        Note:
+            This is only performed on the transforms directly held by this instance. If this
+            instance has nested ``Compose`` transforms or other transforms that contain transforms,
+            it does not iterate into them.
+
+        Args:
+            predicate: a callable that takes a single argument and returns a bool. When called,
+                it is passed a transform from the sequence of transforms contained by this compose
+                instance.
+
+        Returns:
+            The index of the first transform in the sequence for which ``predicate`` returns
+            True, or None if no transform satisfies the ``predicate``.
+
+        """
+        for i in range(len(self.transforms)):
+            if predicate(self.transforms[i]):
+                return i
+        return None
+
    def flatten(self):
        """Return a Composition with a simple list of transforms, as opposed to any nested Compositions.
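Taken together, ``get_index_of_first`` and the new ``start``/``end`` parameters of ``Compose.__call__`` are the building blocks the dataset changes above rely on: locate the first random transform, execute the deterministic prefix once so it can be cached, then run only the random suffix on each fetch. A minimal sketch of that pattern follows; the concrete transforms and the ``image`` input are illustrative assumptions rather than part of this patch:

    from monai.transforms import Compose, Flip, NormalizeIntensity, RandRotate90, Zoom
    from monai.transforms import RandomizableTrait, Transform

    pipeline = Compose([Flip(0), Zoom(0.8), NormalizeIntensity(), RandRotate90(prob=0.5)])

    # index of the first random (or non-Transform) entry; everything before it is deterministic
    first_random = pipeline.get_index_of_first(
        lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
    )

    cacheable = pipeline(image, end=first_random, threading=True)  # run once and cache the result
    result = pipeline(cacheable, start=first_random)  # run the random tail on every fetch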
@@ -293,12 +447,21 @@ def evaluate_with_overrides(self, input_, upcoming_xform): verbose=self.verbose, ) - def __call__(self, input_): - for _transform in self.transforms: - input_ = self.evaluate_with_overrides(input_, _transform) - input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats) - input_ = self.evaluate_with_overrides(input_, None) - return input_ + def __call__(self, input_, start=0, end=None, threading=False): + return execute_compose( + input_, + self.transforms, + start=start, + end=end, + map_items=self.map_items, + unpack_items=self.unpack_items, + lazy_evaluation=self.lazy_evaluation, # type: ignore + overrides=self.overrides, + override_keys=self.override_keys, + threading=threading, + log_stats=self.log_stats, + verbose=self.verbose, + ) # type: ignore def inverse(self, data): invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] @@ -397,12 +560,23 @@ def flatten(self): weights.append(w) return OneOf(transforms, weights, self.map_items, self.unpack_items) - def __call__(self, data): + def __call__(self, data, start=0, end=None, threading=False): if len(self.transforms) == 0: return data + index = self.R.multinomial(1, self.weights).argmax() _transform = self.transforms[index] - data = apply_transform(_transform, data, self.map_items, self.unpack_items, self.log_stats) + + data = execute_compose( + data, + [_transform], + map_items=self.map_items, + unpack_items=self.unpack_items, + start=start, + end=end, + threading=threading, + ) + # if the data is a mapping (dictionary), append the OneOf transform to the end if isinstance(data, monai.data.MetaTensor): self.push_transform(data, extra_info={"index": index}) @@ -481,14 +655,22 @@ def __init__( transforms, map_items, unpack_items, log_stats, lazy_evaluation, overrides, override_keys, verbose ) - def __call__(self, input_): + def __call__(self, input_, start=0, end=None, threading=False): if len(self.transforms) == 0: return input_ num = len(self.transforms) applied_order = self.R.permutation(range(num)) - for index in applied_order: - input_ = apply_transform(self.transforms[index], input_, self.map_items, self.unpack_items, self.log_stats) + input_ = execute_compose( + input_, + [self.transforms[ind] for ind in applied_order], + map_items=self.map_items, + unpack_items=self.unpack_items, + start=start, + end=end, + threading=threading, + ) + # if the data is a mapping (dictionary), append the RandomOrder transform to the end if isinstance(input_, monai.data.MetaTensor): self.push_transform(input_, extra_info={"applied_order": applied_order}) @@ -618,15 +800,22 @@ def _normalize_probabilities(self, weights): return ensure_tuple(list(weights)) - def __call__(self, data): + def __call__(self, data, start=0, end=None, threading=False): if len(self.transforms) == 0: return data sample_size = self.R.randint(self.min_num_transforms, self.max_num_transforms + 1) applied_order = self.R.choice(len(self.transforms), sample_size, replace=self.replace, p=self.weights).tolist() - for i in applied_order: - data = apply_transform(self.transforms[i], data, self.map_items, self.unpack_items, self.log_stats) + data = execute_compose( + data, + [self.transforms[a] for a in applied_order], + map_items=self.map_items, + unpack_items=self.unpack_items, + start=start, + end=end, + threading=threading, + ) if isinstance(data, monai.data.MetaTensor): self.push_transform(data, extra_info={"applied_order": applied_order}) elif isinstance(data, Mapping): diff 
--git a/tests/test_compose.py b/tests/test_compose.py index ddb7ce25d8..47869b02aa 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -13,9 +13,15 @@ import sys import unittest +from copy import deepcopy + +import numpy as np +import torch +from parameterized import parameterized from monai.data import DataLoader, Dataset -from monai.transforms import AddChannel, Compose +from monai.transforms import AddChannel, Compose, Flip, NormalizeIntensity, Rotate, Rotate90, Rotated, Zoom +from monai.transforms.compose import execute_compose from monai.transforms.transform import Randomizable from monai.utils import set_determinism @@ -56,8 +62,12 @@ def b(d): d["b"] += 1 return d - c = Compose([a, b, a, b, a]) - self.assertDictEqual(c({"a": 0, "b": 0}), {"a": 3, "b": 2}) + transforms = [a, b, a, b, a] + data = {"a": 0, "b": 0} + expected = {"a": 3, "b": 2} + + self.assertDictEqual(Compose(transforms)(data), expected) + self.assertDictEqual(execute_compose(data, transforms), expected) def test_list_dict_compose(self): def a(d): # transform to handle dict data @@ -76,10 +86,15 @@ def c(d): # transform to handle dict data d["c"] += 1 return d - transforms = Compose([a, a, b, c, c]) - value = transforms({"a": 0, "b": 0, "c": 0}) + transforms = [a, a, b, c, c] + data = {"a": 0, "b": 0, "c": 0} + expected = {"a": 2, "b": 1, "c": 2} + value = Compose(transforms)(data) + for item in value: + self.assertDictEqual(item, expected) + value = execute_compose(data, transforms) for item in value: - self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2}) + self.assertDictEqual(item, expected) def test_non_dict_compose_with_unpack(self): def a(i, i2): @@ -88,8 +103,11 @@ def a(i, i2): def b(i, i2): return i + "b", i2 + "b2" - c = Compose([a, b, a, b], map_items=False, unpack_items=True) - self.assertEqual(c(("", "")), ("abab", "a2b2a2b2")) + transforms = [a, b, a, b] + data = ("", "") + expected = ("abab", "a2b2a2b2") + self.assertEqual(Compose(transforms, map_items=False, unpack_items=True)(data), expected) + self.assertEqual(execute_compose(data, transforms, map_items=False, unpack_items=True), expected) def test_list_non_dict_compose_with_unpack(self): def a(i, i2): @@ -98,8 +116,11 @@ def a(i, i2): def b(i, i2): return i + "b", i2 + "b2" - c = Compose([a, b, a, b], unpack_items=True) - self.assertEqual(c([("", ""), ("t", "t")]), [("abab", "a2b2a2b2"), ("tabab", "ta2b2a2b2")]) + transforms = [a, b, a, b] + data = [("", ""), ("t", "t")] + expected = [("abab", "a2b2a2b2"), ("tabab", "ta2b2a2b2")] + self.assertEqual(Compose(transforms, unpack_items=True)(data), expected) + self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected) def test_list_dict_compose_no_map(self): def a(d): # transform to handle dict data @@ -119,10 +140,15 @@ def c(d): # transform to handle dict data di["c"] += 1 return d - transforms = Compose([a, a, b, c, c], map_items=False) - value = transforms({"a": 0, "b": 0, "c": 0}) + transforms = [a, a, b, c, c] + data = {"a": 0, "b": 0, "c": 0} + expected = {"a": 2, "b": 1, "c": 2} + value = Compose(transforms, map_items=False)(data) for item in value: - self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2}) + self.assertDictEqual(item, expected) + value = execute_compose(data, transforms, map_items=False) + for item in value: + self.assertDictEqual(item, expected) def test_random_compose(self): class _Acc(Randomizable): @@ -220,5 +246,97 @@ def test_backwards_compatible_imports(self): from monai.transforms.compose import MapTransform, RandomizableTransform, 
Transform  # noqa: F401


+TEST_COMPOSE_EXECUTE_TEST_CASES = [
+    [None, tuple()],
+    [None, (Rotate(np.pi / 8),)],
+    [None, (Flip(0), Flip(1), Rotate90(1), Zoom(0.8), NormalizeIntensity())],
+    [("a",), (Rotated(("a",), np.pi / 8),)],
+]
+
+
+class TestComposeExecute(unittest.TestCase):
+    @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES)
+    def test_compose_execute_equivalence(self, keys, pipeline):
+        if keys is None:
+            data = torch.unsqueeze(torch.tensor(np.arange(24 * 32).reshape(24, 32)), axis=0)
+        else:
+            data = {}
+            for i_k, k in enumerate(keys):
+                data[k] = torch.unsqueeze(torch.tensor(np.arange(24 * 32)).reshape(24, 32) + i_k * 768, axis=0)
+
+        expected = Compose(deepcopy(pipeline))(data)
+
+        for cutoff in range(len(pipeline)):
+            c = Compose(deepcopy(pipeline))
+            actual = c(c(data, end=cutoff), start=cutoff)
+            if isinstance(actual, dict):
+                for k in actual.keys():
+                    self.assertTrue(torch.allclose(expected[k], actual[k]))
+            else:
+                self.assertTrue(torch.allclose(expected, actual))
+
+            p = deepcopy(pipeline)
+            actual = execute_compose(execute_compose(data, p, start=0, end=cutoff), p, start=cutoff)
+            if isinstance(actual, dict):
+                for k in actual.keys():
+                    self.assertTrue(torch.allclose(expected[k], actual[k]))
+            else:
+                self.assertTrue(torch.allclose(expected, actual))
+
+
+class TestOps:
+    @staticmethod
+    def concat(value):
+        def _inner(data):
+            return data + value
+
+        return _inner
+
+    @staticmethod
+    def concatd(value):
+        def _inner(data):
+            return {k: v + value for k, v in data.items()}
+
+        return _inner
+
+    @staticmethod
+    def concata(value):
+        def _inner(data1, data2):
+            return data1 + value, data2 + value
+
+        return _inner
+
+
+TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES = [
+    [{}, ("",), (TestOps.concat("a"), TestOps.concat("b"))],
+    [{"unpack_items": True}, ("x", "y"), (TestOps.concat("a"), TestOps.concat("b"))],
+    [{"map_items": False}, {"x": "1", "y": "2"}, (TestOps.concatd("a"), TestOps.concatd("b"))],
+    [{"unpack_items": True, "map_items": False}, ("x", "y"), (TestOps.concata("a"), TestOps.concata("b"))],
+]
+
+
+class TestComposeExecuteWithFlags(unittest.TestCase):
+    @parameterized.expand(TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES)
+    def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline):
+        expected = Compose(pipeline, **flags)(data)
+
+        for cutoff in range(len(pipeline)):
+            c = Compose(deepcopy(pipeline), **flags)
+            actual = c(c(data, end=cutoff), start=cutoff)
+            if isinstance(actual, dict):
+                for k in actual.keys():
+                    self.assertEqual(expected[k], actual[k])
+            else:
+                self.assertEqual(expected, actual)
+
+            p = deepcopy(pipeline)
+            actual = execute_compose(execute_compose(data, p, start=0, end=cutoff, **flags), p, start=cutoff, **flags)
+            if isinstance(actual, dict):
+                for k in actual.keys():
+                    self.assertEqual(expected[k], actual[k])
+            else:
+                self.assertEqual(expected, actual)
+
+
 if __name__ == "__main__":
     unittest.main()

From 1f294b24612fab34d6f27de08a876519a35769cd Mon Sep 17 00:00:00 2001
From: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com>
Date: Thu, 30 Mar 2023 23:56:03 +0800
Subject: [PATCH 02/10] upgrade pytorch to 23.03 (#6256)

Fixes #6255

### Description

upgrade pytorch to 23.03

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YanxuanLiu Signed-off-by: a-parida12 --- .github/workflows/cron.yml | 8 ++++---- .github/workflows/pythonapp-gpu.yml | 4 ++-- Dockerfile | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index 88f4522b5b..c1015cd541 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -30,7 +30,7 @@ jobs: base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3 - environment: PTLATEST+CUDA118 pytorch: "-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118" - base: "nvcr.io/nvidia/pytorch:23.02-py3" # CUDA 11.8 + base: "nvcr.io/nvidia/pytorch:23.03-py3" # CUDA 11.8 container: image: ${{ matrix.base }} options: "--gpus all" @@ -76,7 +76,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:22.09", "pytorch:22.11", "pytorch:23.02"] + container: ["pytorch:22.09", "pytorch:22.11", "pytorch:23.03"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -121,7 +121,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:22.09", "pytorch:22.11", "pytorch:23.02"] + container: ["pytorch:22.09", "pytorch:22.11", "pytorch:23.03"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -221,7 +221,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' needs: cron-gpu # so that monai itself is verified first container: - image: nvcr.io/nvidia/pytorch:23.02-py3 # testing with the latest pytorch base image + image: nvcr.io/nvidia/pytorch:23.03-py3 # testing with the latest pytorch base image options: "--gpus all --ipc=host" runs-on: [self-hosted, linux, x64, integration] steps: diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index acbd6c648d..164df5efcc 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -42,9 +42,9 @@ jobs: pytorch: "torch==1.13.1 torchvision==0.14.1" base: "nvcr.io/nvidia/cuda:11.6.1-devel-ubuntu18.04" - environment: PT114+CUDA120DOCKER - # 23.02: 1.14.0a0+44dac51 + # 23.03: 2.0.0a0+1767026 pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error - base: "nvcr.io/nvidia/pytorch:23.02-py3" + base: "nvcr.io/nvidia/pytorch:23.03-py3" container: image: ${{ matrix.base }} options: --gpus all --env NVIDIA_DISABLE_REQUIRE=true # workaround for unsatisfied condition: cuda>=11.6 diff --git a/Dockerfile b/Dockerfile index 9bc0eee70c..653dd1571c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. 
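# For example (an illustrative invocation, not part of the original file):
#   docker build --build-arg PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.03-py3 -t monai .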
-ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.02-py3
+ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.03-py3
 FROM ${PYTORCH_IMAGE}

 LABEL maintainer="monai.contact@gmail.com"

From 84ffd47815ab50ecc620ec3adfd01e2b278a053d Mon Sep 17 00:00:00 2001
From: a-parida12
Date: Sun, 2 Apr 2023 21:44:36 -0400
Subject: [PATCH 03/10] feat(SABlock): store attn matrix

Signed-off-by: a-parida12
---
 monai/networks/blocks/selfattention.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 519c8c7728..4ead1327ea 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -25,13 +25,14 @@ class SABlock(nn.Module):
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale "
     """

-    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
+    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn:bool = False) -> None:
         """
         Args:
-            hidden_size: dimension of hidden layer.
-            num_heads: number of attention heads.
-            dropout_rate: fraction of the input units to drop.
-            qkv_bias: bias term for the qkv linear layer.
+            hidden_size (int): dimension of hidden layer.
+            num_heads (int): number of attention heads.
+            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
+            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
+            save_attn (bool, optional): whether to make the attention matrix accessible after the forward pass. Defaults to False.

         """

@@ -52,11 +53,16 @@
         self.drop_weights = nn.Dropout(dropout_rate)
         self.head_dim = hidden_size // num_heads
         self.scale = self.head_dim**-0.5
+        self.save_attn = save_attn

     def forward(self, x):
         output = self.input_rearrange(self.qkv(x))
         q, k, v = output[0], output[1], output[2]
         att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
+        if self.save_attn:
+            # no gradients and new tensor;
+            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+            self.att_mat = att_mat.detach()
         att_mat = self.drop_weights(att_mat)
         x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
         x = self.out_rearrange(x)

From 3f466f715fb960db881aa805d22cdd4da7ddcdb3 Mon Sep 17 00:00:00 2001
From: a-parida12
Date: Sun, 2 Apr 2023 21:48:25 -0400
Subject: [PATCH 04/10] fix(TransformerBlock): allow storing the self-attention matrix

Signed-off-by: a-parida12
---
 monai/networks/blocks/transformerblock.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 3a4b507d69..84bf28b4ee 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -24,7 +24,7 @@ class TransformerBlock(nn.Module):
     """

     def __init__(
-        self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
+        self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False
     ) -> None:
         """
         Args:
             hidden_size: dimension of hidden layer.
             mlp_dim: dimension of feedforward layer.
             num_heads: number of attention heads.
             dropout_rate: fraction of the input units to drop.
             qkv_bias: apply bias term for the qkv linear layer.
+            save_attn: whether to make the attention matrix accessible after the forward pass.

         """

@@ -46,7 +47,7 @@ def __init__(

         self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
         self.norm1 = nn.LayerNorm(hidden_size)
-        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
+        self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
         self.norm2 = nn.LayerNorm(hidden_size)

     def forward(self, x):

From 1763028f890ff2b06f078d7954a2f722e23811f6 Mon Sep 17 00:00:00 2001
From: a-parida12
Date: Sun, 2 Apr 2023 22:12:01 -0400
Subject: [PATCH 05/10] ci(SABlock): test for attn_matrix

Signed-off-by: a-parida12
---
 tests/test_selfattention.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 926ef7da55..3087a37607 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -52,6 +52,22 @@ def test_ill_arg(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

+    def test_access_attn_matrix(self):
+        hidden_size = 128
+        num_heads = 2
+        dropout_rate = 0
+        input_shape = (2, 256, hidden_size)
+        # without save_attn, the attention matrix is not stored and cannot be accessed
+        no_matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)
+        with self.assertRaises(AttributeError):
+            no_matrix_access_blk(torch.randn(input_shape))
+            no_matrix_access_blk.att_mat
+        # with save_attn=True, the attention matrix can be accessed after the forward pass
+        matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True)
+        matrix_access_blk(torch.randn(input_shape))
+        assert matrix_access_blk.att_mat.shape == (input_shape[0], num_heads, input_shape[1], input_shape[1])
+
+
 if __name__ == "__main__":
     unittest.main()

From 9f726fe765540ca3f1f2ba094b72f82bc1b742f7 Mon Sep 17 00:00:00 2001
From: a-parida12
Date: Sun, 2 Apr 2023 22:23:36 -0400
Subject: [PATCH 06/10] formatting improved

Signed-off-by: a-parida12
---
 tests/test_selfattention.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 3087a37607..7ff830d51b 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -53,15 +53,18 @@ def test_ill_arg(self):
         SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

     def test_access_attn_matrix(self):
+        # input format
         hidden_size = 128
         num_heads = 2
         dropout_rate = 0
         input_shape = (2, 256, hidden_size)
+
         # without save_attn, the attention matrix is not stored and cannot be accessed
         no_matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)
         with self.assertRaises(AttributeError):
             no_matrix_access_blk(torch.randn(input_shape))
             no_matrix_access_blk.att_mat
+
         # with save_attn=True, the attention matrix can be accessed after the forward pass
         matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True)
         matrix_access_blk(torch.randn(input_shape))

From 4444b197e2202648d9f7a66a72a21881dc45fd06 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 3 Apr 2023 02:20:44 +0000
Subject: [PATCH 07/10] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: a-parida12
---
 monai/networks/blocks/selfattention.py | 2 +-
 tests/test_selfattention.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 4ead1327ea..20919e8cb3 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -60,7 +60,7 @@ def forward(self, x):
         q, k, v = output[0], output[1], output[2]
         att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
         if self.save_attn:
-            # no gradients and new tensor;
+            # no gradients and new tensor;
             # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
             self.att_mat = att_mat.detach()
         att_mat = self.drop_weights(att_mat)
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 7ff830d51b..9f79521248 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -55,7 +55,7 @@ def test_ill_arg(self):
     def test_access_attn_matrix(self):
         # input format
         hidden_size = 128
-        num_heads = 2
+        num_heads = 2
         dropout_rate = 0
         input_shape = (2, 256, hidden_size)

From 37af689dfc54a3d677ae855fbcc2a00f74db87de Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 3 Apr 2023 02:36:39 +0000
Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/test_selfattention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 9f79521248..bad8793c6a 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -53,18 +53,18 @@ def test_ill_arg(self):
         SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

     def test_access_attn_matrix(self):
-        # input format
+        # input format
         hidden_size = 128
         num_heads = 2
         dropout_rate = 0
         input_shape = (2, 256, hidden_size)
-
+
         # without save_attn, the attention matrix is not stored and cannot be accessed
         no_matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)
         with self.assertRaises(AttributeError):
             no_matrix_access_blk(torch.randn(input_shape))
             no_matrix_access_blk.att_mat
-
+
         # with save_attn=True, the attention matrix can be accessed after the forward pass
         matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True)
         matrix_access_blk(torch.randn(input_shape))

From 86fdc33ae92e68cc78ed3a3374c6d9cabfe567c5 Mon Sep 17 00:00:00 2001
From: monai-bot
Date: Mon, 3 Apr 2023 17:32:49 +0000
Subject: [PATCH 09/10] [MONAI] code formatting

Signed-off-by: monai-bot
---
 monai/networks/blocks/selfattention.py | 9 ++++++++-
 monai/networks/blocks/transformerblock.py | 8 +++++++-
 tests/test_selfattention.py | 5 +++--
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 20919e8cb3..03a7bc7e08 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -25,7 +25,14 @@ class SABlock(nn.Module):
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale "
     """

-    def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn:bool = False) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        dropout_rate: float = 0.0,
+        qkv_bias: bool = False,
+        save_attn: bool = False,
+    ) -> None:
         """
         Args:
             hidden_size (int): dimension of hidden layer.
diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 84bf28b4ee..30f2c2756a 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -24,7 +24,13 @@ class TransformerBlock(nn.Module):
     """

     def __init__(
-        self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False
+        self,
+        hidden_size: int,
+        mlp_dim: int,
+        num_heads: int,
+        dropout_rate: float = 0.0,
+        qkv_bias: bool = False,
+        save_attn: bool = False,
     ) -> None:
         """
         Args:
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index bad8793c6a..a67f54f704 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -66,11 +66,12 @@ def test_access_attn_matrix(self):
             no_matrix_access_blk.att_mat

         # with save_attn=True, the attention matrix can be accessed after the forward pass
-        matrix_access_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True)
+        matrix_access_blk = SABlock(
+            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
+        )
         matrix_access_blk(torch.randn(input_shape))
         assert matrix_access_blk.att_mat.shape == (input_shape[0], num_heads, input_shape[1], input_shape[1])

-
 if __name__ == "__main__":
     unittest.main()

From 5718c2cee63648b124b2327b5e2108f44013f56d Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Mon, 3 Apr 2023 19:20:23 +0100
Subject: [PATCH 10/10] fixes mypy error

Signed-off-by: Wenqi Li
---
 monai/data/image_reader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py
index 03bffbb1e8..43086964ae 100644
--- a/monai/data/image_reader.py
+++ b/monai/data/image_reader.py
@@ -929,7 +929,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
         for name in filenames:
             img = nib.load(name, **kwargs_)
             img = correct_nifti_header_if_necessary(img)
-            img_.append(img)
+            img_.append(img)  # type: ignore
         return img_ if len(filenames) > 1 else img_[0]

     def get_data(self, img) -> tuple[np.ndarray, dict]:
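As a closing illustration of the ``save_attn`` option introduced in patches 03-09, a minimal usage sketch (the sizes mirror ``test_access_attn_matrix`` above; the snippet itself is an illustrative example, not part of the patch series):

    import torch

    from monai.networks.blocks.selfattention import SABlock

    blk = SABlock(hidden_size=128, num_heads=2, dropout_rate=0.0, save_attn=True)
    x = torch.randn(2, 256, 128)  # (batch, sequence length, hidden_size)
    _ = blk(x)  # the forward pass stores a detached copy of the attention weights
    print(blk.att_mat.shape)  # torch.Size([2, 2, 256, 256]): (batch, heads, seq, seq)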