From 79719fabd4dc024340dfc98326b9574fa4ff5e1d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Fri, 3 Nov 2023 21:56:51 +0100
Subject: [PATCH 01/88] mixup, cutmix and cutout
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/regularization/__init__.py | 14 ++
monai/regularization/mixup.py | 214 +++++++++++++++++++++++++++++++
tests/test_regularization.py | 86 +++++++++++++
3 files changed, 314 insertions(+)
create mode 100644 monai/regularization/__init__.py
create mode 100644 monai/regularization/mixup.py
create mode 100644 tests/test_regularization.py
diff --git a/monai/regularization/__init__.py b/monai/regularization/__init__.py
new file mode 100644
index 0000000000..30455df3bd
--- /dev/null
+++ b/monai/regularization/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .mixup import MixUp, MixUpd, CutMix, CutMixd
+
+__all__ = ["MixUp", "MixUpd", "CutMix", "CutMixd"]
diff --git a/monai/regularization/mixup.py b/monai/regularization/mixup.py
new file mode 100644
index 0000000000..ae974a94ef
--- /dev/null
+++ b/monai/regularization/mixup.py
@@ -0,0 +1,214 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+from collections.abc import Sequence
+from typing import Tuple
+from monai.config import KeysCollection
+import torch
+from monai.transforms import Transform, MapTransform
+from monai.utils.misc import ensure_tuple
+from math import sqrt, ceil
+
+
+class Mixer(Transform):
+ def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
+ super().__init__()
+ if alpha <= 0:
+ raise ValueError(f"Expected positive number, but got {alpha = }")
+ self._sampler = torch.distributions.beta.Beta(alpha, alpha)
+ self.batch_size = batch_size
+
+ def sample_params(self):
+ """
+ Sometimes you need may to apply the same transform to different tensors.
+ The idea is to get a sample and then apply it with apply_mixup() as often
+ as needed.
+ """
+ return self._sampler.sample((self.batch_size,)), torch.randperm(self.batch_size)
+
+ @classmethod
+ @abstractmethod
+ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
+ raise NotImplementedError()
+
+
+class MixUp(Mixer):
+ """MixUp as described in:
+ Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.
+ mixup: Beyond Empirical Risk Minimization, ICLR 2018
+ """
+
+ def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
+ super().__init__(batch_size, alpha)
+
+ @classmethod
+ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
+ weight, perm = params
+ nsamples, *dims = data.shape
+ if len(weight) != nsamples:
+ raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}")
+
+ if len(dims) not in [3, 4]:
+ raise ValueError("Unexpected number of dimensions")
+
+ mixweight = weight[(Ellipsis,) + (None,) * len(dims)]
+ return mixweight * data + (1 - mixweight) * data[perm, ...]
+
+ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
+ if labels is None:
+ return self.apply(self.sample_params(), data)
+
+ params = self.sample_params()
+ return self.apply(params, data), self.apply(params, labels)
+
+
+class MixUpd(MapTransform):
+ """MixUp as described in:
+ Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.
+ mixup: Beyond Empirical Risk Minimization, ICLR 2018
+
+ Notice that the mixup transformation will be the same for all entries
+ for consistency, i.e. images and labels must be applied the same augmenation.
+ """
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ batch_size: int,
+ alpha: float = 1.0,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.mixup = MixUp(batch_size, alpha)
+
+ def __call__(self, data):
+ result = dict(data)
+ params = self.mixup.sample_params()
+ for k in self.keys:
+ result[k] = self.mixup.apply(params, data[k])
+ return result
+
+
+class CutMix(Mixer):
+ """CutMix augmentation as described in:
+ Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo
+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
+ ICCV 2019
+ """
+
+ def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
+ super().__init__(batch_size, alpha)
+
+ @classmethod
+ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
+ weights, perm = params
+ nsamples, _, *dims = data.shape
+ if len(weights) != nsamples:
+ raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
+
+ mask = torch.ones_like(data)
+ for s, weight in enumerate(weights):
+ coords = [torch.randint(0, d, size=(1,)) for d in dims]
+ lengths = [d * sqrt(1 - weight) for d in dims]
+ idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
+ mask[s][idx] = 0
+
+ return mask * data + (1 - mask) * data[perm, ...]
+
+ @classmethod
+ def apply_on_labels(cls, params: Tuple[torch.Tensor, torch.Tensor], labels: torch.Tensor):
+ return MixUp.apply(params, labels)
+
+ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
+ params = self.sample_params()
+ augmented = self.apply(params, data)
+ return (augmented, MixUp.apply(params, labels)) if labels is not None else augmented
+
+
+class CutMixd(MapTransform):
+ """CutMix augmentation as described in:
+ Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo
+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
+ ICCV 2019
+
+ Notice that the mixture weights will be the same for all entries
+ for consistency, i.e. images and labels must be aggregated with the same weights,
+ but the random crops are not.
+ """
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ batch_size: int,
+ label_keys: KeysCollection | None = None,
+ alpha: float = 1.0,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.mixer = CutMix(batch_size, alpha)
+ self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []
+
+ def __call__(self, data):
+ result = dict(data)
+ params = self.mixer.sample_params()
+ for k in self.keys:
+ result[k] = self.mixer.apply(params, data[k])
+ for k in self.label_keys:
+ result[k] = self.mixer.apply_on_labels(params, data[k])
+ return result
+
+
+class CutOut(Mixer):
+ """Cutout as described in the paper:
+ Terrance DeVries, Graham W. Taylor
+ Improved Regularization of Convolutional Neural Networks with Cutout
+ arXiv:1708.04552
+ """
+
+ @classmethod
+ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
+ weights, _ = params
+ nsamples, _, *dims = data.shape
+ if len(weights) != nsamples:
+ raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
+
+ mask = torch.ones_like(data)
+ for s, weight in enumerate(weights):
+ coords = [torch.randint(0, d, size=(1,)) for d in dims]
+ lengths = [d * sqrt(1 - weight) for d in dims]
+ idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
+ mask[s][idx] = 0
+
+ return mask * data
+
+ def __call__(self, data: torch.Tensor):
+ return self.apply(self.sample_params(), data)
+
+
+class CutOutd(MapTransform):
+ """Cutout as described in the paper:
+ Terrance DeVries, Graham W. Taylor
+ Improved Regularization of Convolutional Neural Networks with Cutout
+ arXiv:1708.04552
+
+ Notice that the cutout is different for every entry in the dictionary.
+ """
+
+ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.cutout = CutOut(batch_size)
+
+ def __call__(self, data):
+ result = dict(data)
+ for k in self.keys:
+ result[k] = self.cutout(data[k])
+ return result
diff --git a/tests/test_regularization.py b/tests/test_regularization.py
new file mode 100644
index 0000000000..daa8ab0a02
--- /dev/null
+++ b/tests/test_regularization.py
@@ -0,0 +1,86 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from monai.regularization.mixup import MixUp, MixUpd, CutMix, CutMixd, CutOut
+import unittest
+
+
+class TestMixup(unittest.TestCase):
+ def test_mixup(self):
+ for dims in [2, 3]:
+ shape = (6, 3) + (32,) * dims
+ sample = torch.rand(*shape, dtype=torch.float32)
+ mixup = MixUp(6, 1.0)
+ output = mixup(sample)
+ self.assertEqual(output.shape, sample.shape)
+ self.assertTrue(any([not torch.allclose(sample, mixup(sample)) for _ in range(10)]))
+
+ with self.assertRaises(ValueError):
+ MixUp(6, -0.5)
+
+ mixup = MixUp(6, 0.5)
+ for dims in [2, 3]:
+ with self.assertRaises(ValueError):
+ shape = (5, 3) + (32,) * dims
+ sample = torch.rand(*shape, dtype=torch.float32)
+ mixup(sample)
+
+ def test_mixupd(self):
+ for dims in [2, 3]:
+ shape = (6, 3) + (32,) * dims
+ t = torch.rand(*shape, dtype=torch.float32)
+ sample = {"a": t, "b": t}
+ mixup = MixUpd(["a", "b"], 6)
+ output = mixup(sample)
+ self.assertTrue(torch.allclose(output["a"], output["b"]))
+
+ with self.assertRaises(ValueError):
+ MixUpd(["k1", "k2"], 6, -0.5)
+
+
+class TestCutMix(unittest.TestCase):
+ def test_cutmix(self):
+ for dims in [2, 3]:
+ shape = (6, 3) + (32,) * dims
+ sample = torch.rand(*shape, dtype=torch.float32)
+ cutmix = CutMix(6, 1.0)
+ output = cutmix(sample)
+ self.assertEqual(output.shape, sample.shape)
+ self.assertTrue(any([not torch.allclose(sample, cutmix(sample)) for _ in range(10)]))
+
+ def test_cutmixd(self):
+ for dims in [2, 3]:
+ shape = (6, 3) + (32,) * dims
+ t = torch.rand(*shape, dtype=torch.float32)
+ label = torch.randint(0, 1, shape)
+ sample = {"a": t, "b": t, "lbl1": label, "lbl2": label}
+ cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2"))
+ output = cutmix(sample)
+ # croppings are different on each application
+ self.assertTrue(not torch.allclose(output["a"], output["b"]))
+ # but mixing of labels is not affected by it
+ self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"]))
+
+
+class TestCutOut(unittest.TestCase):
+ def test_cutout(self):
+ for dims in [2, 3]:
+ shape = (6, 3) + (32,) * dims
+ sample = torch.rand(*shape, dtype=torch.float32)
+ cutout = CutOut(6, 1.0)
+ output = cutout(sample)
+ self.assertEqual(output.shape, sample.shape)
+ self.assertTrue(any([not torch.allclose(sample, cutout(sample)) for _ in range(10)]))
+
+
+if __name__ == "__main__":
+ unittest.main()
From 1528cdf0b3b53963c09b1c2034adc55ff8bdcb4d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Fri, 3 Nov 2023 22:17:27 +0100
Subject: [PATCH 02/88] added rst file
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
docs/source/regularization.rst | 28 ++++++++++++++++++++++++++++
1 file changed, 28 insertions(+)
create mode 100644 docs/source/regularization.rst
diff --git a/docs/source/regularization.rst b/docs/source/regularization.rst
new file mode 100644
index 0000000000..e0decbca89
--- /dev/null
+++ b/docs/source/regularization.rst
@@ -0,0 +1,28 @@
+:github_url: https://github.com/Project-MONAI/MONAI
+
+.. _regularization:
+
+Regularization Strategies
+=========================
+
+Data Augmentation
+-------------------
+
+.. automodule:: monai.regularization
+.. currentmodule:: monai.regularization
+
+`MixUp`
+^^^^^^^
+.. autoclass:: MixUp
+ :members:
+
+
+`CutMix`
+^^^^^^^^
+.. autoclass:: CutMix
+ :members:
+
+`CutOut`
+^^^^^^^^
+.. autoclass:: CutOut
+ :members:
From ff456860b3a75117f386207ecc47c5414702d641 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 3 Nov 2023 21:26:19 +0000
Subject: [PATCH 03/88] [pre-commit.ci] auto fixes from pre-commit.com hooks
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
for more information, see https://pre-commit.ci
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/regularization/mixup.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/monai/regularization/mixup.py b/monai/regularization/mixup.py
index ae974a94ef..63b841dcaa 100644
--- a/monai/regularization/mixup.py
+++ b/monai/regularization/mixup.py
@@ -10,7 +10,6 @@
# limitations under the License.
from abc import abstractmethod
-from collections.abc import Sequence
from typing import Tuple
from monai.config import KeysCollection
import torch
From ea89145fa23f2b99980ed25cb4657849c845276c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Fri, 3 Nov 2023 23:10:09 +0100
Subject: [PATCH 04/88] added missing module
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/__init__.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/monai/__init__.py b/monai/__init__.py
index 638220f6df..ec865d203a 100644
--- a/monai/__init__.py
+++ b/monai/__init__.py
@@ -74,6 +74,7 @@
"metrics",
"networks",
"optimizers",
+ "regularization",
"transforms",
"utils",
"visualize",
From d35c1b39851c7f319e38b5a607957cd463cc681b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Mon, 6 Nov 2023 21:00:43 +0100
Subject: [PATCH 05/88] refactor code as submodule of transforms module
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
docs/source/regularization.rst | 28 -----
docs/source/transforms.rst | 42 +++++++
docs/source/transforms_idx.rst | 10 ++
monai/__init__.py | 1 -
monai/transforms/__init__.py | 12 ++
.../regularization/__init__.py | 4 -
.../regularization/array.py} | 83 +-------------
monai/transforms/regularization/dictionary.py | 106 ++++++++++++++++++
tests/test_regularization.py | 2 +-
9 files changed, 173 insertions(+), 115 deletions(-)
delete mode 100644 docs/source/regularization.rst
rename monai/{ => transforms}/regularization/__init__.py (84%)
rename monai/{regularization/mixup.py => transforms/regularization/array.py} (65%)
create mode 100644 monai/transforms/regularization/dictionary.py
diff --git a/docs/source/regularization.rst b/docs/source/regularization.rst
deleted file mode 100644
index e0decbca89..0000000000
--- a/docs/source/regularization.rst
+++ /dev/null
@@ -1,28 +0,0 @@
-:github_url: https://github.com/Project-MONAI/MONAI
-
-.. _regularization:
-
-Regularization Strategies
-=========================
-
-Data Augmentation
--------------------
-
-.. automodule:: monai.regularization
-.. currentmodule:: monai.regularization
-
-`MixUp`
-^^^^^^^
-.. autoclass:: MixUp
- :members:
-
-
-`CutMix`
-^^^^^^^^
-.. autoclass:: CutMix
- :members:
-
-`CutOut`
-^^^^^^^^
-.. autoclass:: CutOut
- :members:
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 8990e7991d..bd3feb3497 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -661,6 +661,27 @@ Post-processing
:members:
:special-members: __call__
+Regularization
+^^^^^^^^^^^^^^
+
+`CutMix`
+""""""""
+.. autoclass:: CutMix
+ :members:
+ :special-members: __call__
+
+`CutOut`
+""""""""
+.. autoclass:: CutOut
+ :members:
+ :special-members: __call__
+
+`MixUp`
+"""""""
+.. autoclass:: MixUp
+ :members:
+ :special-members: __call__
+
Signal
^^^^^^^
@@ -1707,6 +1728,27 @@ Post-processing (Dict)
:members:
:special-members: __call__
+Regularization (Dict)
+^^^^^^^^^^^^^^^^^^^^^
+
+`CutMixd`
+"""""""""
+.. autoclass:: CutMixd
+ :members:
+ :special-members: __call__
+
+`CutOutd`
+"""""""""
+.. autoclass:: CutOutd
+ :members:
+ :special-members: __call__
+
+`MixUpd`
+""""""""
+.. autoclass:: MixUpd
+ :members:
+ :special-members: __call__
+
Signal (Dict)
^^^^^^^^^^^^^
diff --git a/docs/source/transforms_idx.rst b/docs/source/transforms_idx.rst
index f4d02a483f..650d45db71 100644
--- a/docs/source/transforms_idx.rst
+++ b/docs/source/transforms_idx.rst
@@ -74,6 +74,16 @@ Post-processing
post.array
post.dictionary
+Regularization
+^^^^^^^^^^^^^^
+
+.. autosummary::
+ :toctree: _gen
+ :nosignatures:
+
+ regularization.array
+ regularization.dictionary
+
Signal
^^^^^^
diff --git a/monai/__init__.py b/monai/__init__.py
index ec865d203a..638220f6df 100644
--- a/monai/__init__.py
+++ b/monai/__init__.py
@@ -74,7 +74,6 @@
"metrics",
"networks",
"optimizers",
- "regularization",
"transforms",
"utils",
"visualize",
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index 2aa8fbf8a1..a7d7d17362 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -693,3 +693,15 @@
unravel_index,
where,
)
+from .regularization.array import MixUp, CutMix, CutOut
+from .regularization.dictionary import (
+ CutMixd,
+ CutMixD,
+ CutMixDict,
+ MixUpd,
+ MixUpD,
+ MixUpDict,
+ CutOutd,
+ CutOutD,
+ CutOutDict,
+)
diff --git a/monai/regularization/__init__.py b/monai/transforms/regularization/__init__.py
similarity index 84%
rename from monai/regularization/__init__.py
rename to monai/transforms/regularization/__init__.py
index 30455df3bd..1e97f89407 100644
--- a/monai/regularization/__init__.py
+++ b/monai/transforms/regularization/__init__.py
@@ -8,7 +8,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from .mixup import MixUp, MixUpd, CutMix, CutMixd
-
-__all__ = ["MixUp", "MixUpd", "CutMix", "CutMixd"]
diff --git a/monai/regularization/mixup.py b/monai/transforms/regularization/array.py
similarity index 65%
rename from monai/regularization/mixup.py
rename to monai/transforms/regularization/array.py
index 63b841dcaa..e61740a976 100644
--- a/monai/regularization/mixup.py
+++ b/monai/transforms/regularization/array.py
@@ -14,9 +14,10 @@
from monai.config import KeysCollection
import torch
from monai.transforms import Transform, MapTransform
-from monai.utils.misc import ensure_tuple
from math import sqrt, ceil
+__all__ = ["MixUp", "CutMix", "CutOut"]
+
class Mixer(Transform):
def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
@@ -70,33 +71,6 @@ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
return self.apply(params, data), self.apply(params, labels)
-class MixUpd(MapTransform):
- """MixUp as described in:
- Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.
- mixup: Beyond Empirical Risk Minimization, ICLR 2018
-
- Notice that the mixup transformation will be the same for all entries
- for consistency, i.e. images and labels must be applied the same augmenation.
- """
-
- def __init__(
- self,
- keys: KeysCollection,
- batch_size: int,
- alpha: float = 1.0,
- allow_missing_keys: bool = False,
- ) -> None:
- super().__init__(keys, allow_missing_keys)
- self.mixup = MixUp(batch_size, alpha)
-
- def __call__(self, data):
- result = dict(data)
- params = self.mixup.sample_params()
- for k in self.keys:
- result[k] = self.mixup.apply(params, data[k])
- return result
-
-
class CutMix(Mixer):
"""CutMix augmentation as described in:
Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo
@@ -133,39 +107,6 @@ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
return (augmented, MixUp.apply(params, labels)) if labels is not None else augmented
-class CutMixd(MapTransform):
- """CutMix augmentation as described in:
- Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo
- CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
- ICCV 2019
-
- Notice that the mixture weights will be the same for all entries
- for consistency, i.e. images and labels must be aggregated with the same weights,
- but the random crops are not.
- """
-
- def __init__(
- self,
- keys: KeysCollection,
- batch_size: int,
- label_keys: KeysCollection | None = None,
- alpha: float = 1.0,
- allow_missing_keys: bool = False,
- ) -> None:
- super().__init__(keys, allow_missing_keys)
- self.mixer = CutMix(batch_size, alpha)
- self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []
-
- def __call__(self, data):
- result = dict(data)
- params = self.mixer.sample_params()
- for k in self.keys:
- result[k] = self.mixer.apply(params, data[k])
- for k in self.label_keys:
- result[k] = self.mixer.apply_on_labels(params, data[k])
- return result
-
-
class CutOut(Mixer):
"""Cutout as described in the paper:
Terrance DeVries, Graham W. Taylor
@@ -191,23 +132,3 @@ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
def __call__(self, data: torch.Tensor):
return self.apply(self.sample_params(), data)
-
-
-class CutOutd(MapTransform):
- """Cutout as described in the paper:
- Terrance DeVries, Graham W. Taylor
- Improved Regularization of Convolutional Neural Networks with Cutout
- arXiv:1708.04552
-
- Notice that the cutout is different for every entry in the dictionary.
- """
-
- def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None:
- super().__init__(keys, allow_missing_keys)
- self.cutout = CutOut(batch_size)
-
- def __call__(self, data):
- result = dict(data)
- for k in self.keys:
- result[k] = self.cutout(data[k])
- return result
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
new file mode 100644
index 0000000000..fab55be2b3
--- /dev/null
+++ b/monai/transforms/regularization/dictionary.py
@@ -0,0 +1,106 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from monai.config import KeysCollection
+from monai.transforms import MapTransform
+from monai.utils.misc import ensure_tuple
+from .array import MixUp, CutMix, CutOut
+
+__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"]
+
+
+class MixUpd(MapTransform):
+ """MixUp as described in:
+ Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.
+ mixup: Beyond Empirical Risk Minimization, ICLR 2018
+
+ Notice that the mixup transformation will be the same for all entries
+ for consistency, i.e. images and labels must be applied the same augmenation.
+ """
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ batch_size: int,
+ alpha: float = 1.0,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.mixup = MixUp(batch_size, alpha)
+
+ def __call__(self, data):
+ result = dict(data)
+ params = self.mixup.sample_params()
+ for k in self.keys:
+ result[k] = self.mixup.apply(params, data[k])
+ return result
+
+
+MixUpD = MixUpDict = MixUpd
+
+
+class CutMixd(MapTransform):
+ """CutMix augmentation as described in:
+ Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo
+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
+ ICCV 2019
+
+ Notice that the mixture weights will be the same for all entries
+ for consistency, i.e. images and labels must be aggregated with the same weights,
+ but the random crops are not.
+ """
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ batch_size: int,
+ label_keys: KeysCollection | None = None,
+ alpha: float = 1.0,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.mixer = CutMix(batch_size, alpha)
+ self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []
+
+ def __call__(self, data):
+ result = dict(data)
+ params = self.mixer.sample_params()
+ for k in self.keys:
+ result[k] = self.mixer.apply(params, data[k])
+ for k in self.label_keys:
+ result[k] = self.mixer.apply_on_labels(params, data[k])
+ return result
+
+
+CutMixD = CutMixDict = CutMixd
+
+
+class CutOutd(MapTransform):
+ """Cutout as described in the paper:
+ Terrance DeVries, Graham W. Taylor
+ Improved Regularization of Convolutional Neural Networks with Cutout
+ arXiv:1708.04552
+
+ Notice that the cutout is different for every entry in the dictionary.
+ """
+
+ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.cutout = CutOut(batch_size)
+
+ def __call__(self, data):
+ result = dict(data)
+ for k in self.keys:
+ result[k] = self.cutout(data[k])
+ return result
+
+
+CutOutD = CutOutDict = CutOutd
diff --git a/tests/test_regularization.py b/tests/test_regularization.py
index daa8ab0a02..2fc3dee699 100644
--- a/tests/test_regularization.py
+++ b/tests/test_regularization.py
@@ -10,7 +10,7 @@
# limitations under the License.
import torch
-from monai.regularization.mixup import MixUp, MixUpd, CutMix, CutMixd, CutOut
+from monai.transforms import MixUp, MixUpd, CutMix, CutMixd, CutOut
import unittest
From 1c54f1cd8704a4c7afac423ca79ed363a13618b6 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 6 Nov 2023 20:01:22 +0000
Subject: [PATCH 06/88] [pre-commit.ci] auto fixes from pre-commit.com hooks
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
for more information, see https://pre-commit.ci
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/regularization/array.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py
index e61740a976..109fd414f4 100644
--- a/monai/transforms/regularization/array.py
+++ b/monai/transforms/regularization/array.py
@@ -11,9 +11,8 @@
from abc import abstractmethod
from typing import Tuple
-from monai.config import KeysCollection
import torch
-from monai.transforms import Transform, MapTransform
+from monai.transforms import Transform
from math import sqrt, ceil
__all__ = ["MixUp", "CutMix", "CutOut"]
From 9b0559331e2e31f3e34b7cccb7f70f7c2d55b751 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Tue, 7 Nov 2023 18:43:55 +0100
Subject: [PATCH 07/88] use the randomizable API
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/regularization/array.py | 76 ++++++++++---------
monai/transforms/regularization/dictionary.py | 11 +--
2 files changed, 46 insertions(+), 41 deletions(-)
diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py
index 109fd414f4..d244da24a9 100644
--- a/monai/transforms/regularization/array.py
+++ b/monai/transforms/regularization/array.py
@@ -10,34 +10,36 @@
# limitations under the License.
from abc import abstractmethod
-from typing import Tuple
import torch
-from monai.transforms import Transform
+from monai.transforms import Transform, Randomizable
from math import sqrt, ceil
__all__ = ["MixUp", "CutMix", "CutOut"]
-class Mixer(Transform):
+class Mixer(Transform, Randomizable):
def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
super().__init__()
if alpha <= 0:
raise ValueError(f"Expected positive number, but got {alpha = }")
- self._sampler = torch.distributions.beta.Beta(alpha, alpha)
+ self.alpha = alpha
self.batch_size = batch_size
- def sample_params(self):
+ @abstractmethod
+ def apply(cls, data: torch.Tensor):
+ raise NotImplementedError()
+
+ def randomize(self, data=None) -> None:
"""
Sometimes you need may to apply the same transform to different tensors.
- The idea is to get a sample and then apply it with apply_mixup() as often
- as needed.
+ The idea is to get a sample and then apply it with apply() as often
+ as needed. You need to call this method everytime you apply the transform to a new
+ batch.
"""
- return self._sampler.sample((self.batch_size,)), torch.randperm(self.batch_size)
-
- @classmethod
- @abstractmethod
- def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
- raise NotImplementedError()
+ self._params = (
+ torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32),
+ self.R.permutation(self.batch_size),
+ )
class MixUp(Mixer):
@@ -49,9 +51,8 @@ class MixUp(Mixer):
def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
super().__init__(batch_size, alpha)
- @classmethod
- def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
- weight, perm = params
+ def apply(self, data: torch.Tensor):
+ weight, perm = self._params
nsamples, *dims = data.shape
if len(weight) != nsamples:
raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}")
@@ -63,16 +64,15 @@ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
return mixweight * data + (1 - mixweight) * data[perm, ...]
def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
+ self.randomize()
if labels is None:
- return self.apply(self.sample_params(), data)
-
- params = self.sample_params()
- return self.apply(params, data), self.apply(params, labels)
+ return self.apply(data)
+ return self.apply(data), self.apply(labels)
class CutMix(Mixer):
"""CutMix augmentation as described in:
- Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo
+ Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
ICCV 2019
"""
@@ -80,9 +80,8 @@ class CutMix(Mixer):
def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
super().__init__(batch_size, alpha)
- @classmethod
- def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
- weights, perm = params
+ def apply(self, data: torch.Tensor):
+ weights, perm = self._params
nsamples, _, *dims = data.shape
if len(weights) != nsamples:
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
@@ -96,26 +95,30 @@ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
return mask * data + (1 - mask) * data[perm, ...]
- @classmethod
- def apply_on_labels(cls, params: Tuple[torch.Tensor, torch.Tensor], labels: torch.Tensor):
- return MixUp.apply(params, labels)
+ def apply_on_labels(self, labels: torch.Tensor):
+ weights, perm = self._params
+ nsamples, *dims = labels.shape
+ if len(weights) != nsamples:
+ raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
+
+ mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
+ return mixweight * labels + (1 - mixweight) * labels[perm, ...]
def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
- params = self.sample_params()
- augmented = self.apply(params, data)
- return (augmented, MixUp.apply(params, labels)) if labels is not None else augmented
+ self.randomize()
+ augmented = self.apply(data)
+ return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented
class CutOut(Mixer):
"""Cutout as described in the paper:
- Terrance DeVries, Graham W. Taylor
- Improved Regularization of Convolutional Neural Networks with Cutout
+ Terrance DeVries, Graham W. Taylor.
+ Improved Regularization of Convolutional Neural Networks with Cutout,
arXiv:1708.04552
"""
- @classmethod
- def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
- weights, _ = params
+ def apply(self, data: torch.Tensor):
+ weights, _ = self._params
nsamples, _, *dims = data.shape
if len(weights) != nsamples:
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
@@ -130,4 +133,5 @@ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
return mask * data
def __call__(self, data: torch.Tensor):
- return self.apply(self.sample_params(), data)
+ self.randomize()
+ return self.apply(data)
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
index fab55be2b3..99414653b2 100644
--- a/monai/transforms/regularization/dictionary.py
+++ b/monai/transforms/regularization/dictionary.py
@@ -37,10 +37,10 @@ def __init__(
self.mixup = MixUp(batch_size, alpha)
def __call__(self, data):
+ self.mixup.randomize()
result = dict(data)
- params = self.mixup.sample_params()
for k in self.keys:
- result[k] = self.mixup.apply(params, data[k])
+ result[k] = self.mixup.apply(data[k])
return result
@@ -71,12 +71,12 @@ def __init__(
self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []
def __call__(self, data):
+ self.mixer.randomize()
result = dict(data)
- params = self.mixer.sample_params()
for k in self.keys:
- result[k] = self.mixer.apply(params, data[k])
+ result[k] = self.mixer.apply(data[k])
for k in self.label_keys:
- result[k] = self.mixer.apply_on_labels(params, data[k])
+ result[k] = self.mixer.apply_on_labels(data[k])
return result
@@ -98,6 +98,7 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo
def __call__(self, data):
result = dict(data)
+ self.cutout.randomize()
for k in self.keys:
result[k] = self.cutout(data[k])
return result
From 83b2c98b4fd1221cf711a0eb0b1a1cd6a5bed60a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Tue, 7 Nov 2023 19:00:37 +0100
Subject: [PATCH 08/88] used types compatible with python <3.10
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/regularization/array.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py
index d244da24a9..6926657f5c 100644
--- a/monai/transforms/regularization/array.py
+++ b/monai/transforms/regularization/array.py
@@ -10,6 +10,7 @@
# limitations under the License.
from abc import abstractmethod
+from typing import Optional
import torch
from monai.transforms import Transform, Randomizable
from math import sqrt, ceil
@@ -63,7 +64,7 @@ def apply(self, data: torch.Tensor):
mixweight = weight[(Ellipsis,) + (None,) * len(dims)]
return mixweight * data + (1 - mixweight) * data[perm, ...]
- def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
+ def __call__(self, data: torch.Tensor, labels: Optional[torch.Tensor] = None):
self.randomize()
if labels is None:
return self.apply(data)
@@ -104,7 +105,7 @@ def apply_on_labels(self, labels: torch.Tensor):
mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
return mixweight * labels + (1 - mixweight) * labels[perm, ...]
- def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
+ def __call__(self, data: torch.Tensor, labels: Optional[torch.Tensor] = None):
self.randomize()
augmented = self.apply(data)
return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented
From efd5e99d6b906c8168c84b42fc68762488daccad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Tue, 7 Nov 2023 19:03:18 +0100
Subject: [PATCH 09/88] used types compatible with python <3.10
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/regularization/dictionary.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
index 99414653b2..e6b201a76b 100644
--- a/monai/transforms/regularization/dictionary.py
+++ b/monai/transforms/regularization/dictionary.py
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from monai.config import KeysCollection
from monai.transforms import MapTransform
from monai.utils.misc import ensure_tuple
@@ -62,7 +63,7 @@ def __init__(
self,
keys: KeysCollection,
batch_size: int,
- label_keys: KeysCollection | None = None,
+ label_keys: Optional[KeysCollection] = None,
alpha: float = 1.0,
allow_missing_keys: bool = False,
) -> None:
From f6af5cffaa4c5b335d9bfa588ffdca5818af2fbf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Tue, 7 Nov 2023 19:49:02 +0100
Subject: [PATCH 10/88] fixed isort errors
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/regularization/array.py | 8 ++++++--
monai/transforms/regularization/dictionary.py | 6 +++++-
2 files changed, 11 insertions(+), 3 deletions(-)
diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py
index 6926657f5c..29d4f45be5 100644
--- a/monai/transforms/regularization/array.py
+++ b/monai/transforms/regularization/array.py
@@ -9,11 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
+
from abc import abstractmethod
+from math import ceil, sqrt
from typing import Optional
+
import torch
-from monai.transforms import Transform, Randomizable
-from math import sqrt, ceil
+
+from monai.transforms import Randomizable, Transform
__all__ = ["MixUp", "CutMix", "CutOut"]
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
index e6b201a76b..a01943333f 100644
--- a/monai/transforms/regularization/dictionary.py
+++ b/monai/transforms/regularization/dictionary.py
@@ -9,11 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
+
from typing import Optional
+
from monai.config import KeysCollection
from monai.transforms import MapTransform
from monai.utils.misc import ensure_tuple
-from .array import MixUp, CutMix, CutOut
+
+from .array import CutMix, CutOut, MixUp
__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"]
From 8e7ec733f0c4ca45ae8359b2465cf31a0007cef3 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 7 Nov 2023 18:49:34 +0000
Subject: [PATCH 11/88] [pre-commit.ci] auto fixes from pre-commit.com hooks
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
for more information, see https://pre-commit.ci
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/regularization/array.py | 5 ++---
monai/transforms/regularization/dictionary.py | 3 +--
2 files changed, 3 insertions(+), 5 deletions(-)
diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py
index 29d4f45be5..1d9b5538fb 100644
--- a/monai/transforms/regularization/array.py
+++ b/monai/transforms/regularization/array.py
@@ -13,7 +13,6 @@
from abc import abstractmethod
from math import ceil, sqrt
-from typing import Optional
import torch
@@ -68,7 +67,7 @@ def apply(self, data: torch.Tensor):
mixweight = weight[(Ellipsis,) + (None,) * len(dims)]
return mixweight * data + (1 - mixweight) * data[perm, ...]
- def __call__(self, data: torch.Tensor, labels: Optional[torch.Tensor] = None):
+ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
self.randomize()
if labels is None:
return self.apply(data)
@@ -109,7 +108,7 @@ def apply_on_labels(self, labels: torch.Tensor):
mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
return mixweight * labels + (1 - mixweight) * labels[perm, ...]
- def __call__(self, data: torch.Tensor, labels: Optional[torch.Tensor] = None):
+ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
self.randomize()
augmented = self.apply(data)
return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
index a01943333f..3447b1e5a9 100644
--- a/monai/transforms/regularization/dictionary.py
+++ b/monai/transforms/regularization/dictionary.py
@@ -11,7 +11,6 @@
from __future__ import annotations
-from typing import Optional
from monai.config import KeysCollection
from monai.transforms import MapTransform
@@ -67,7 +66,7 @@ def __init__(
self,
keys: KeysCollection,
batch_size: int,
- label_keys: Optional[KeysCollection] = None,
+ label_keys: KeysCollection | None = None,
alpha: float = 1.0,
allow_missing_keys: bool = False,
) -> None:
From 5eb8c7fecdef2441fe96e904e6b7bbc125d3d82a Mon Sep 17 00:00:00 2001
From: monai-bot <64792179+monai-bot@users.noreply.github.com>
Date: Mon, 6 Nov 2023 08:57:08 +0000
Subject: [PATCH 12/88] auto updates (#7203)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: monai-bot
Signed-off-by: monai-bot
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
tests/test_squeeze_unsqueeze.py | 1 -
tests/test_voxelmorph.py | 1 -
2 files changed, 2 deletions(-)
diff --git a/tests/test_squeeze_unsqueeze.py b/tests/test_squeeze_unsqueeze.py
index 2db26a6bdc..130a214345 100644
--- a/tests/test_squeeze_unsqueeze.py
+++ b/tests/test_squeeze_unsqueeze.py
@@ -28,7 +28,6 @@
(torch.rand(3, 4).type(torch.int32), 5, (3, 4, 1, 1, 1)),
]
-
LEFT_CASES = [
(np.random.rand(3, 4).astype(np.float32), 5, (1, 1, 1, 3, 4)),
(torch.rand(3, 4).type(torch.float32), 5, (1, 1, 1, 3, 4)),
diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py
index c51f70cbf5..53ef2fc18f 100644
--- a/tests/test_voxelmorph.py
+++ b/tests/test_voxelmorph.py
@@ -229,7 +229,6 @@
ILL_CASES = [ILL_CASE_0, ILL_CASE_1, ILL_CASE_2, ILL_CASE_3, ILL_CASE_4, ILL_CASE_5]
-
ILL_CASES_IN_SHAPE_0 = [ # moving and fixed image shape not match
{"spatial_dims": 3},
(1, 2, 96, 96, 48),
From 674812f712a486483a2e4873ab1d9bfc7713f62c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Tue, 14 Nov 2023 21:32:21 +0100
Subject: [PATCH 13/88] changes from command ./runtests.sh --autofix
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/__init__.py | 24 +++++++++----------
monai/transforms/regularization/dictionary.py | 7 +-----
tests/test_regularization.py | 8 +++++--
3 files changed, 19 insertions(+), 20 deletions(-)
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index a7d7d17362..349533fb3e 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -336,6 +336,18 @@
VoteEnsembled,
VoteEnsembleDict,
)
+from .regularization.array import CutMix, CutOut, MixUp
+from .regularization.dictionary import (
+ CutMixd,
+ CutMixD,
+ CutMixDict,
+ CutOutd,
+ CutOutD,
+ CutOutDict,
+ MixUpd,
+ MixUpD,
+ MixUpDict,
+)
from .signal.array import (
SignalContinuousWavelet,
SignalFillEmpty,
@@ -693,15 +705,3 @@
unravel_index,
where,
)
-from .regularization.array import MixUp, CutMix, CutOut
-from .regularization.dictionary import (
- CutMixd,
- CutMixD,
- CutMixDict,
- MixUpd,
- MixUpD,
- MixUpDict,
- CutOutd,
- CutOutD,
- CutOutDict,
-)
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
index 3447b1e5a9..86a361b8a4 100644
--- a/monai/transforms/regularization/dictionary.py
+++ b/monai/transforms/regularization/dictionary.py
@@ -11,7 +11,6 @@
from __future__ import annotations
-
from monai.config import KeysCollection
from monai.transforms import MapTransform
from monai.utils.misc import ensure_tuple
@@ -31,11 +30,7 @@ class MixUpd(MapTransform):
"""
def __init__(
- self,
- keys: KeysCollection,
- batch_size: int,
- alpha: float = 1.0,
- allow_missing_keys: bool = False,
+ self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
) -> None:
super().__init__(keys, allow_missing_keys)
self.mixup = MixUp(batch_size, alpha)
diff --git a/tests/test_regularization.py b/tests/test_regularization.py
index 2fc3dee699..8b974e392d 100644
--- a/tests/test_regularization.py
+++ b/tests/test_regularization.py
@@ -9,10 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-from monai.transforms import MixUp, MixUpd, CutMix, CutMixd, CutOut
+from __future__ import annotations
+
import unittest
+import torch
+
+from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd
+
class TestMixup(unittest.TestCase):
def test_mixup(self):
From c43f8cadd524762534224e504856e6584ffae65e Mon Sep 17 00:00:00 2001
From: elitap
Date: Wed, 15 Nov 2023 12:00:21 +0100
Subject: [PATCH 14/88] fix useless error msg in nnunetv2runner (#7217)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes # fix useless error msg in nnunetv2runner
### Description
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
Signed-off-by: elitap
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/apps/nnunet/nnunetv2_runner.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/monai/apps/nnunet/nnunetv2_runner.py b/monai/apps/nnunet/nnunetv2_runner.py
index a3357cd9b3..e62809403e 100644
--- a/monai/apps/nnunet/nnunetv2_runner.py
+++ b/monai/apps/nnunet/nnunetv2_runner.py
@@ -275,8 +275,8 @@ def convert_dataset(self):
num_input_channels=num_input_channels,
output_datafolder=raw_data_foldername,
)
- except BaseException:
- logger.warning("Input config may be incorrect. Detail info: error/exception message is:\n {err}")
+ except BaseException as err:
+ logger.warning(f"Input config may be incorrect. Detail info: error/exception message is:\n {err}")
return
def convert_msd_dataset(self, data_dir: str, overwrite_id: str | None = None, n_proc: int = -1) -> None:
From cad947e58f0755837b468019e0c86e50ccb0ad49 Mon Sep 17 00:00:00 2001
From: Felix Schnabel
Date: Thu, 16 Nov 2023 19:49:14 +0100
Subject: [PATCH 15/88] Fixup mypy 1.7.0 errors (#7231)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7230.
### Description
Fix the typing issues and the deprecation.
Also always run type checking with Linux environment, since
ForkServerContext is not available on Windows.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
---------
Signed-off-by: Felix Schnabel
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.gitignore | 1 +
monai/apps/auto3dseg/data_analyzer.py | 2 +-
monai/apps/deepgrow/interaction.py | 2 +-
monai/apps/pathology/metrics/lesion_froc.py | 5 ++++-
monai/metrics/utils.py | 4 ++--
monai/transforms/io/array.py | 2 +-
requirements-dev.txt | 2 +-
setup.cfg | 2 +-
8 files changed, 12 insertions(+), 8 deletions(-)
diff --git a/.gitignore b/.gitignore
index 8c66d4a651..437677d2bb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -115,6 +115,7 @@ venv.bak/
examples/scd_lvsegs.npz
temp/
.idea/
+.dmypy.json
*~
diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py
index 2c485f03eb..9280fb5be5 100644
--- a/monai/apps/auto3dseg/data_analyzer.py
+++ b/monai/apps/auto3dseg/data_analyzer.py
@@ -210,7 +210,7 @@ def get_all_case_stats(self, key="training", transform_list=None):
nprocs = torch.cuda.device_count()
logger.info(f"Found {nprocs} GPUs for data analyzing!")
if nprocs > 1:
- tmp_ctx = get_context("forkserver")
+ tmp_ctx: Any = get_context("forkserver")
with tmp_ctx.Manager() as manager:
manager_list = manager.list()
processes = []
diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py
index 88211c31e3..fa3a28bfef 100644
--- a/monai/apps/deepgrow/interaction.py
+++ b/monai/apps/deepgrow/interaction.py
@@ -49,7 +49,7 @@ def __init__(
if not isinstance(transforms, Compose):
transforms = Compose(transforms)
- self.transforms = transforms
+ self.transforms: Compose = transforms
self.max_interactions = max_interactions
self.train = train
self.key_probability = key_probability
diff --git a/monai/apps/pathology/metrics/lesion_froc.py b/monai/apps/pathology/metrics/lesion_froc.py
index 0899de9a76..f4bf51ab28 100644
--- a/monai/apps/pathology/metrics/lesion_froc.py
+++ b/monai/apps/pathology/metrics/lesion_froc.py
@@ -11,7 +11,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable
import numpy as np
@@ -94,6 +94,9 @@ def prepare_inference_result(self, sample: dict) -> tuple[np.ndarray, np.ndarray
nms_outputs = self.nms(probs_map=prob_map, resolution_level=sample["level"])
# separate nms outputs
+ probs: Iterable[Any]
+ x_coord: Iterable[Any]
+ y_coord: Iterable[Any]
if nms_outputs:
probs, x_coord, y_coord = zip(*nms_outputs)
else:
diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py
index c139fc35ed..4d4e6570c5 100644
--- a/monai/metrics/utils.py
+++ b/monai/metrics/utils.py
@@ -14,7 +14,7 @@
import warnings
from functools import lru_cache, partial
from types import ModuleType
-from typing import Any, Sequence
+from typing import Any, Iterable, Sequence
import numpy as np
import torch
@@ -383,7 +383,7 @@ def remap_instance_id(pred: torch.Tensor, by_size: bool = False) -> torch.Tensor
by_size: if True, largest instance will be assigned a smaller id.
"""
- pred_id = list(pred.unique())
+ pred_id: Iterable[Any] = list(pred.unique())
# the original implementation has the limitation that if there is no 0 in pred, error will happen
pred_id = [i for i in pred_id if i != 0]
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index b36c011822..cd7e4ef090 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -268,7 +268,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
break
if img is None or reader is None:
- if isinstance(filename, tuple) and len(filename) == 1:
+ if isinstance(filename, Sequence) and len(filename) == 1:
filename = filename[0]
msg = "\n".join([f"{e}" for e in err])
raise RuntimeError(
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 38715b8449..6332d5b0a5 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -22,7 +22,7 @@ isort>=5.1
ruff
pytype>=2020.6.1; platform_system != "Windows"
types-pkg_resources
-mypy>=0.790
+mypy>=1.5.0
ninja
torchvision
psutil
diff --git a/setup.cfg b/setup.cfg
index d6c9b4f190..123da68dfa 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -228,7 +228,7 @@ pretty = False
# Warns about per-module sections in the config file that do not match any files processed when invoking mypy.
warn_unused_configs = True
# Make arguments prepended via Concatenate be truly positional-only.
-strict_concatenate = True
+extra_checks = True
# Allows variables to be redefined with an arbitrary type,
# as long as the redefinition is in the same block and nesting level as the original definition.
# allow_redefinition = True
From 81097abcc67538168518a95f3f4661784915311c Mon Sep 17 00:00:00 2001
From: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com>
Date: Fri, 17 Nov 2023 11:21:56 +0800
Subject: [PATCH 16/88] add Yun Liu to user list to trigger blossom-ci [skip
ci] (#7239)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes # .
### Description
A few sentences describing the changes proposed in this pull request.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Add new developer to blossom-ci trigger list
Signed-off-by: YanxuanLiu
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/blossom-ci.yml | 1 +
1 file changed, 1 insertion(+)
diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml
index 5e4612c6c5..1d6ee8a46c 100644
--- a/.github/workflows/blossom-ci.yml
+++ b/.github/workflows/blossom-ci.yml
@@ -34,6 +34,7 @@ jobs:
wyli,\
pxLi,\
YanxuanLiu,\
+ KumoLiu,\
', format('{0},', github.actor)) && github.event.comment.body == '/build'
steps:
- name: Check if comment is issued by authorized person
From da27e2e356db6a23617a06f66a653963c333a5bc Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 17 Nov 2023 16:18:07 +0800
Subject: [PATCH 17/88] =?UTF-8?q?Replace=20single=20quotation=20marks=20wi?=
=?UTF-8?q?th=20double=20quotation=20marks=20to=20install=E2=80=A6=20(#723?=
=?UTF-8?q?4)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
… MONAI with all dependencies on Windows
Fixes #6118
### Description
The Windows shell doesn't recognize single quotes to delimit a string at
all, so on Windows you'll need to use double quotes. It's the same
command no matter which type of quotes you use; after the shell does its
processing, the argument is passed to pip with the quotation marks
already removed.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: ytl0623
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
docs/source/installation.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/source/installation.md b/docs/source/installation.md
index 88107c9487..d77253f0f9 100644
--- a/docs/source/installation.md
+++ b/docs/source/installation.md
@@ -228,7 +228,7 @@ Alternatively, to install all optional dependencies:
```bash
git clone https://github.com/Project-MONAI/MONAI.git
cd MONAI/
-pip install -e '.[all]'
+pip install -e ".[all]"
```
To install all optional dependencies with `pip` based on MONAI development environment settings:
From f665179ffc9791641a1a98c56caac2df7d1ff61f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?D=C5=BEenan=20Zuki=C4=87?=
Date: Fri, 17 Nov 2023 04:56:52 -0500
Subject: [PATCH 18/88] Update bug_report.md (#7213)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Avoid syntax error on Windows
### Description
```log
(.venv) C:\Dev\Kitware\python>python -c 'import monai; monai.config.print_debug_info()'
File "", line 1
'import
^
SyntaxError: unterminated string literal (detected at line 1)
```
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: Dženan Zukić
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/ISSUE_TEMPLATE/bug_report.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index cebcdfc917..4ae6c07732 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -27,7 +27,7 @@ If applicable, add screenshots to help explain your problem.
Ensuring you use the relevant python executable, please paste the output of:
```
-python -c 'import monai; monai.config.print_debug_info()'
+python -c "import monai; monai.config.print_debug_info()"
```
**Additional context**
From 36beba13ae85f8b016f42e7adac605cc029a2785 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Fri, 17 Nov 2023 22:14:17 +0800
Subject: [PATCH 19/88] Add cache option in `GridPatchDataset` (#7180)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Part of #6904
### Description
- Fix inefficient patching in `PatchDataset`
- Add cache option in `GridPatchDataset`
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: KumoLiu
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/data/grid_dataset.py | 218 ++++++++++++++++++++++++++++++------
tests/test_grid_dataset.py | 55 +++++++--
tests/test_patch_dataset.py | 15 ++-
3 files changed, 242 insertions(+), 46 deletions(-)
diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py
index 06954e9f11..9079032e6f 100644
--- a/monai/data/grid_dataset.py
+++ b/monai/data/grid_dataset.py
@@ -11,18 +11,30 @@
from __future__ import annotations
-from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence
+import sys
+import warnings
+from collections.abc import Callable, Generator, Hashable, Iterable, Iterator, Mapping, Sequence
from copy import deepcopy
+from multiprocessing.managers import ListProxy
+from multiprocessing.pool import ThreadPool
+from typing import TYPE_CHECKING
import numpy as np
+import torch
from monai.config import KeysCollection
from monai.config.type_definitions import NdarrayTensor
-from monai.data.dataset import Dataset
from monai.data.iterable_dataset import IterableDataset
-from monai.data.utils import iter_patch
-from monai.transforms import apply_transform
-from monai.utils import NumpyPadMode, ensure_tuple, first
+from monai.data.utils import iter_patch, pickle_hashing
+from monai.transforms import Compose, RandomizableTrait, Transform, apply_transform, convert_to_contiguous
+from monai.utils import NumpyPadMode, ensure_tuple, first, min_version, optional_import
+
+if TYPE_CHECKING:
+ from tqdm import tqdm
+
+ has_tqdm = True
+else:
+ tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")
__all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"]
@@ -184,6 +196,25 @@ class GridPatchDataset(IterableDataset):
see also: :py:class:`monai.data.PatchIter` or :py:class:`monai.data.PatchIterd`.
transform: a callable data transform operates on the patches.
with_coordinates: whether to yield the coordinates of each patch, default to `True`.
+ cache: whether to use cache mache mechanism, default to `False`.
+ see also: :py:class:`monai.data.CacheDataset`.
+ cache_num: number of items to be cached. Default is `sys.maxsize`.
+ will take the minimum of (cache_num, data_length x cache_rate, data_length).
+ cache_rate: percentage of cached data in total, default is 1.0 (cache all).
+ will take the minimum of (cache_num, data_length x cache_rate, data_length).
+ num_workers: the number of worker threads if computing cache in the initialization.
+ If num_workers is None then the number returned by os.cpu_count() is used.
+ If a value less than 1 is specified, 1 will be used instead.
+ progress: whether to display a progress bar.
+ copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
+ default to `True`. if the random transforms don't modify the cached content
+ (for example, randomly crop from the cached image and deepcopy the crop region)
+ or if every cache item is only used once in a `multi-processing` environment,
+ may set `copy=False` for better performance.
+ as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
+ it may help improve the performance of following logic.
+ hash_func: a callable to compute hash from data items to be cached.
+ defaults to `monai.data.utils.pickle_hashing`.
"""
@@ -193,27 +224,148 @@ def __init__(
patch_iter: Callable,
transform: Callable | None = None,
with_coordinates: bool = True,
+ cache: bool = False,
+ cache_num: int = sys.maxsize,
+ cache_rate: float = 1.0,
+ num_workers: int | None = 1,
+ progress: bool = True,
+ copy_cache: bool = True,
+ as_contiguous: bool = True,
+ hash_func: Callable[..., bytes] = pickle_hashing,
) -> None:
super().__init__(data=data, transform=None)
+ if transform is not None and not isinstance(transform, Compose):
+ transform = Compose(transform)
self.patch_iter = patch_iter
self.patch_transform = transform
self.with_coordinates = with_coordinates
+ self.set_num = cache_num
+ self.set_rate = cache_rate
+ self.progress = progress
+ self.copy_cache = copy_cache
+ self.as_contiguous = as_contiguous
+ self.hash_func = hash_func
+ self.num_workers = num_workers
+ if self.num_workers is not None:
+ self.num_workers = max(int(self.num_workers), 1)
+ self._cache: list | ListProxy = []
+ self._cache_other: list | ListProxy = []
+ self.cache = cache
+ self.first_random: int | None = None
+ if self.patch_transform is not None:
+ self.first_random = self.patch_transform.get_index_of_first(
+ lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
+ )
- def __iter__(self):
- for image in super().__iter__():
- for patch, *others in self.patch_iter(image):
- out_patch = patch
- if self.patch_transform is not None:
- out_patch = apply_transform(self.patch_transform, patch, map_items=False)
- if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords
- yield out_patch, others[0]
- else:
- yield out_patch
+ if self.cache:
+ if isinstance(data, Iterator):
+ raise TypeError("Data can not be iterator when cache is True")
+ self.set_data(data) # type: ignore
+
+ def set_data(self, data: Sequence) -> None:
+ """
+ Set the input data and run deterministic transforms to generate cache content.
+
+ Note: should call this func after an entire epoch and must set `persistent_workers=False`
+ in PyTorch DataLoader, because it needs to create new worker processes based on new
+ generated cache content.
+
+ """
+ self.data = data
+
+ # only compute cache for the unique items of dataset, and record the last index for duplicated items
+ mapping = {self.hash_func(v): i for i, v in enumerate(self.data)}
+ self.cache_num = min(int(self.set_num), int(len(mapping) * self.set_rate), len(mapping))
+ self._hash_keys = list(mapping)[: self.cache_num]
+ indices = list(mapping.values())[: self.cache_num]
+ self._cache, self._cache_other = zip(*self._fill_cache(indices)) # type: ignore
+
+ def _fill_cache(self, indices=None) -> list:
+ """
+ Compute and fill the cache content from data source.
+
+ Args:
+ indices: target indices in the `self.data` source to compute cache.
+ if None, use the first `cache_num` items.
+
+ """
+ if self.cache_num <= 0:
+ return []
+ if indices is None:
+ indices = list(range(self.cache_num))
+ if self.progress and not has_tqdm:
+ warnings.warn("tqdm is not installed, will not show the caching progress bar.")
+
+ pfunc = tqdm if self.progress and has_tqdm else (lambda v, **_: v)
+ with ThreadPool(self.num_workers) as p:
+ return list(pfunc(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset"))
+
+ def _load_cache_item(self, idx: int):
+ """
+ Args:
+ idx: the index of the input data sequence.
+ """
+ item = self.data[idx] # type: ignore
+ patch_cache, other_cache = [], []
+ for patch, *others in self.patch_iter(item):
+ if self.first_random is not None:
+ patch = self.patch_transform(patch, end=self.first_random, threading=True) # type: ignore
+
+ if self.as_contiguous:
+ patch = convert_to_contiguous(patch, memory_format=torch.contiguous_format)
+ if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords
+ other_cache.append(others[0])
+ patch_cache.append(patch)
+ return patch_cache, other_cache
+
+ def _generate_patches(self, src, **apply_args):
+ """
+ yield patches optionally post-processed by transform.
+ Args:
+ src: a iterable of image patches.
+ apply_args: other args for `self.patch_transform`.
+
+ """
+ for patch, *others in src:
+ out_patch = patch
+ if self.patch_transform is not None:
+ out_patch = self.patch_transform(patch, **apply_args)
+ if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords
+ yield out_patch, others[0]
+ else:
+ yield out_patch
-class PatchDataset(Dataset):
+ def __iter__(self):
+ if self.cache:
+ cache_index = None
+ for image in super().__iter__():
+ key = self.hash_func(image)
+ if key in self._hash_keys:
+ # if existing in cache, try to get the index in cache
+ cache_index = self._hash_keys.index(key)
+ if cache_index is None:
+ # no cache for this index, execute all the transforms directly
+ yield from self._generate_patches(self.patch_iter(image))
+ else:
+ if self._cache is None:
+ raise RuntimeError(
+ "Cache buffer is not initialized, please call `set_data()` before epoch begins."
+ )
+ data = self._cache[cache_index] # type: ignore
+ other = self._cache_other[cache_index] # type: ignore
+
+ # load data from cache and execute from the first random transform
+ data = deepcopy(data) if self.copy_cache else data
+ yield from self._generate_patches(zip(data, other), start=self.first_random)
+ else:
+ for image in super().__iter__():
+ yield from self._generate_patches(self.patch_iter(image))
+
+
+class PatchDataset(IterableDataset):
"""
- returns a patch from an image dataset.
+ Yields patches from data read from an image dataset.
The patches are generated by a user-specified callable `patch_func`,
and are optionally post-processed by `transform`.
For example, to generate random patch samples from an image dataset:
@@ -263,26 +415,26 @@ def __init__(
samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements.
transform: transform applied to each patch.
"""
- super().__init__(data=data, transform=transform)
+ super().__init__(data=data, transform=None)
self.patch_func = patch_func
if samples_per_image <= 0:
raise ValueError("sampler_per_image must be a positive integer.")
self.samples_per_image = int(samples_per_image)
+ self.patch_transform = transform
def __len__(self) -> int:
- return len(self.data) * self.samples_per_image
-
- def _transform(self, index: int):
- image_id = int(index / self.samples_per_image)
- image = self.data[image_id]
- patches = self.patch_func(image)
- if len(patches) != self.samples_per_image:
- raise RuntimeWarning(
- f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}."
- )
- patch_id = (index - image_id * self.samples_per_image) * (-1 if index < 0 else 1)
- patch = patches[patch_id]
- if self.transform is not None:
- patch = apply_transform(self.transform, patch, map_items=False)
- return patch
+ return len(self.data) * self.samples_per_image # type: ignore
+
+ def __iter__(self):
+ for image in super().__iter__():
+ patches = self.patch_func(image)
+ if len(patches) != self.samples_per_image:
+ raise RuntimeWarning(
+ f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}."
+ )
+ for patch in patches:
+ out_patch = patch
+ if self.patch_transform is not None:
+ out_patch = apply_transform(self.patch_transform, patch, map_items=False)
+ yield out_patch
diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py
index ba33547260..d937a5e266 100644
--- a/tests/test_grid_dataset.py
+++ b/tests/test_grid_dataset.py
@@ -108,11 +108,10 @@ def test_shape(self):
self.assertEqual(sorted(output), sorted(expected))
def test_loading_array(self):
- set_determinism(seed=1234)
# test sequence input data with images
images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)]
# image level
- patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)
+ patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0).set_random_state(seed=1234)
patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity)
# use the grid patch dataset
@@ -120,7 +119,7 @@ def test_loading_array(self):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(
item[0],
- np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
+ np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]),
rtol=1e-4,
)
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
@@ -129,9 +128,7 @@ def test_loading_array(self):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(
item[0],
- np.array(
- [[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
- ),
+ np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]),
rtol=1e-3,
)
np.testing.assert_allclose(
@@ -164,7 +161,7 @@ def test_loading_dict(self):
self.assertListEqual(item[0]["metadata"], ["test string", "test string"])
np.testing.assert_allclose(
item[0]["image"],
- np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
+ np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]),
rtol=1e-4,
)
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
@@ -173,15 +170,53 @@ def test_loading_dict(self):
np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2))
np.testing.assert_allclose(
item[0]["image"],
- np.array(
- [[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
- ),
+ np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]),
rtol=1e-3,
)
np.testing.assert_allclose(
item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5
)
+ def test_set_data(self):
+ from monai.transforms import Compose, Lambda, RandLambda
+
+ images = [np.arange(2, 18, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)]
+
+ transform = Compose(
+ [Lambda(func=lambda x: np.array(x * 10)), RandLambda(func=lambda x: x + 1)], map_items=False
+ )
+ patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
+ dataset = GridPatchDataset(
+ data=images,
+ patch_iter=patch_iter,
+ transform=transform,
+ cache=True,
+ cache_rate=1.0,
+ copy_cache=not sys.platform == "linux",
+ )
+
+ num_workers = 2 if sys.platform == "linux" else 0
+ for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):
+ np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
+ np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4)
+ np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
+ # simulate another epoch, the cache content should not be modified
+ for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):
+ np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
+ np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4)
+ np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
+
+ # update the datalist and fill the cache content
+ data_list2 = [np.arange(1, 17, dtype=float).reshape(1, 4, 4)]
+ dataset.set_data(data=data_list2)
+ # rerun with updated cache content
+ for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):
+ np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
+ np.testing.assert_allclose(
+ item[0], np.array([[[[91, 101], [131, 141]]], [[[111, 121], [151, 161]]]]), rtol=1e-4
+ )
+ np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py
index 7d66bdccbb..eb705f0c61 100644
--- a/tests/test_patch_dataset.py
+++ b/tests/test_patch_dataset.py
@@ -37,7 +37,10 @@ def test_shape(self):
n_workers = 0 if sys.platform == "win32" else 2
for item in DataLoader(result, batch_size=3, num_workers=n_workers):
output.append("".join(item))
- expected = ["vwx", "yzh", "ell", "owo", "rld"]
+ if n_workers == 0:
+ expected = ["vwx", "yzh", "ell", "owo", "rld"]
+ else:
+ expected = ["vwx", "hel", "yzw", "lo", "orl", "d"]
self.assertEqual(output, expected)
def test_loading_array(self):
@@ -61,7 +64,7 @@ def test_loading_array(self):
np.testing.assert_allclose(
item[0],
np.array(
- [[[-0.593095, 0.406905, 1.406905], [3.406905, 4.406905, 5.406905], [7.406905, 8.406905, 9.406905]]]
+ [[[4.970372, 5.970372, 6.970372], [8.970372, 9.970372, 10.970372], [12.970372, 13.970372, 14.970372]]]
),
rtol=1e-5,
)
@@ -71,7 +74,13 @@ def test_loading_array(self):
np.testing.assert_allclose(
item[0],
np.array(
- [[[0.234308, 1.234308, 2.234308], [4.234308, 5.234308, 6.234308], [8.234308, 9.234308, 10.234308]]]
+ [
+ [
+ [5.028125, 6.028125, 7.028125],
+ [9.028125, 10.028125, 11.028125],
+ [13.028125, 14.028125, 15.028125],
+ ]
+ ]
),
rtol=1e-5,
)
From af45bb2318e7c54f4d89c69a57c7ab91c02dc7b5 Mon Sep 17 00:00:00 2001
From: Ishan Dutta
Date: Sun, 19 Nov 2023 21:49:16 +0530
Subject: [PATCH 20/88] :memo: [array] Add examples for EnsureType and
CastToType (#7245)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7101
### Description
Added examples in the docstrings for `EnsureType` and `CastToType`
transforms which show how they function under different circumstances.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: Ishan Dutta
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/utility/array.py | 35 ++++++++++++++++++++++++++++++-
1 file changed, 34 insertions(+), 1 deletion(-)
diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index 9aad12ef90..caf02d7b00 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -333,6 +333,23 @@ class CastToType(Transform):
"""
Cast the Numpy data to specified numpy data type, or cast the PyTorch Tensor to
specified PyTorch data type.
+
+ Example:
+ >>> import numpy as np
+ >>> import torch
+ >>> transform = CastToType(dtype=np.float32)
+
+ >>> # Example with a numpy array
+ >>> img_np = np.array([0, 127, 255], dtype=np.uint8)
+ >>> img_np_casted = transform(img_np)
+ >>> img_np_casted
+ array([ 0. , 127. , 255. ], dtype=float32)
+
+ >>> # Example with a PyTorch tensor
+ >>> img_tensor = torch.tensor([0, 127, 255], dtype=torch.uint8)
+ >>> img_tensor_casted = transform(img_tensor)
+ >>> img_tensor_casted
+ tensor([ 0., 127., 255.]) # dtype is float32
"""
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
@@ -413,10 +430,26 @@ class EnsureType(Transform):
dtype: target data content type to convert, for example: np.float32, torch.float, etc.
device: for Tensor data type, specify the target device.
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
- E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
track_meta: if `True` convert to ``MetaTensor``, otherwise to Pytorch ``Tensor``,
if ``None`` behave according to return value of py:func:`monai.data.meta_obj.get_track_meta`.
+ Example with wrap_sequence=True:
+ >>> import numpy as np
+ >>> import torch
+ >>> transform = EnsureType(data_type="tensor", wrap_sequence=True)
+ >>> # Converting a list to a tensor
+ >>> data_list = [1, 2., 3]
+ >>> tensor_data = transform(data_list)
+ >>> tensor_data
+ tensor([1., 2., 3.]) # All elements have dtype float32
+
+ Example with wrap_sequence=False:
+ >>> transform = EnsureType(data_type="tensor", wrap_sequence=False)
+ >>> # Converting each element in a list to individual tensors
+ >>> data_list = [1, 2, 3]
+ >>> tensors_list = transform(data_list)
+ >>> tensors_list
+ [tensor(1), tensor(2.), tensor(3)] # Only second element is float32 rest are int64
"""
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
From fede45b0e663eb171b980ff4c105a02263e99e14 Mon Sep 17 00:00:00 2001
From: Ishan Dutta
Date: Mon, 20 Nov 2023 08:02:39 +0530
Subject: [PATCH 21/88] :hammer: [dataset] Handle corrupted cached file in
PersistentDataset (#7244)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #5723
### Description
Corrupted cached files in the PersistentDataset cause the exception:
`RuntimeError: Invalid magic number; corrupt file?`
With this PR we handle that case in the try-except block and continue
the usual functionality if the cached file was absent.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: Ishan Dutta
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/data/dataset.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index 5e403d6fdb..eba850225d 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -387,6 +387,12 @@ def _cachecheck(self, item_transformed):
except PermissionError as e:
if sys.platform != "win32":
raise e
+ except RuntimeError as e:
+ if "Invalid magic number; corrupt file" in str(e):
+ warnings.warn(f"Corrupt cache file detected: {hashfile}. Deleting and recomputing.")
+ hashfile.unlink()
+ else:
+ raise e
_item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed
if hashfile is None:
From 4bc75041cc68804ce29694fc8cddafbe5ee537d0 Mon Sep 17 00:00:00 2001
From: monai-bot <64792179+monai-bot@users.noreply.github.com>
Date: Mon, 20 Nov 2023 08:34:18 +0000
Subject: [PATCH 22/88] auto updates (#7247)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: monai-bot
Signed-off-by: monai-bot
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/apps/detection/transforms/box_ops.py | 2 +-
monai/data/grid_dataset.py | 4 ++--
monai/data/image_writer.py | 2 +-
monai/data/wsi_reader.py | 2 +-
monai/metrics/utils.py | 2 +-
monai/networks/nets/resnet.py | 6 +++---
monai/transforms/croppad/array.py | 8 +++-----
monai/transforms/spatial/array.py | 2 +-
monai/transforms/utility/array.py | 2 +-
monai/transforms/utils.py | 8 ++++----
tests/test_inverse.py | 2 +-
tests/utils.py | 4 ++--
12 files changed, 21 insertions(+), 23 deletions(-)
diff --git a/monai/apps/detection/transforms/box_ops.py b/monai/apps/detection/transforms/box_ops.py
index fb870c952e..404854c4c0 100644
--- a/monai/apps/detection/transforms/box_ops.py
+++ b/monai/apps/detection/transforms/box_ops.py
@@ -407,7 +407,7 @@ def rot90_boxes(
spatial_dims: int = get_spatial_dims(boxes=boxes)
spatial_size_ = list(ensure_tuple_rep(spatial_size, spatial_dims))
- axes = ensure_tuple(axes) # type: ignore
+ axes = ensure_tuple(axes)
if len(axes) != 2:
raise ValueError("len(axes) must be 2.")
diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py
index 9079032e6f..689138179a 100644
--- a/monai/data/grid_dataset.py
+++ b/monai/data/grid_dataset.py
@@ -352,8 +352,8 @@ def __iter__(self):
raise RuntimeError(
"Cache buffer is not initialized, please call `set_data()` before epoch begins."
)
- data = self._cache[cache_index] # type: ignore
- other = self._cache_other[cache_index] # type: ignore
+ data = self._cache[cache_index]
+ other = self._cache_other[cache_index]
# load data from cache and execute from the first random transform
data = deepcopy(data) if self.copy_cache else data
diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py
index db0bfa96b8..b9e8b9e68e 100644
--- a/monai/data/image_writer.py
+++ b/monai/data/image_writer.py
@@ -276,7 +276,7 @@ def resample_if_needed(
# convert back at the end
if isinstance(output_array, MetaTensor):
output_array.applied_operations = []
- data_array, *_ = convert_data_type(output_array, output_type=orig_type) # type: ignore
+ data_array, *_ = convert_data_type(output_array, output_type=orig_type)
affine, *_ = convert_data_type(output_array.affine, output_type=orig_type) # type: ignore
return data_array[0], affine
diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py
index 54e12eb0cd..b31d4d9c3a 100644
--- a/monai/data/wsi_reader.py
+++ b/monai/data/wsi_reader.py
@@ -111,7 +111,7 @@ def __init__(
self.set_device(device)
self.mode = mode
self.kwargs = kwargs
- self.mpp: tuple[float, float] | None = ensure_tuple_rep(mpp, 2) if mpp is not None else None # type: ignore
+ self.mpp: tuple[float, float] | None = ensure_tuple_rep(mpp, 2) if mpp is not None else None
self.power = power
self.mpp_rtol = mpp_rtol
self.mpp_atol = mpp_atol
diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py
index 4d4e6570c5..62e6520b96 100644
--- a/monai/metrics/utils.py
+++ b/monai/metrics/utils.py
@@ -205,7 +205,7 @@ def get_mask_edges(
or_vol = seg_pred | seg_gt
if not or_vol.any():
pred, gt = lib.zeros(seg_pred.shape, dtype=bool), lib.zeros(seg_gt.shape, dtype=bool)
- return (pred, gt) if spacing is None else (pred, gt, pred, gt) # type: ignore
+ return (pred, gt) if spacing is None else (pred, gt, pred, gt)
channel_first = [seg_pred[None], seg_gt[None], or_vol[None]]
if spacing is None and not use_cucim: # cpu only erosion
seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, device="cpu", dtype=bool)
diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py
index fca73f4de3..34a4b7057e 100644
--- a/monai/networks/nets/resnet.py
+++ b/monai/networks/nets/resnet.py
@@ -239,9 +239,9 @@ def __init__(
self.conv1 = conv_type(
n_input_channels,
self.in_planes,
- kernel_size=conv1_kernel_size, # type: ignore
- stride=conv1_stride, # type: ignore
- padding=tuple(k // 2 for k in conv1_kernel_size), # type: ignore
+ kernel_size=conv1_kernel_size,
+ stride=conv1_stride,
+ padding=tuple(k // 2 for k in conv1_kernel_size),
bias=False,
)
self.bn1 = norm_type(self.in_planes)
diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index 6a3798e7ba..ce3701b263 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -386,7 +386,7 @@ def compute_slices(
if roi_slices:
if not all(s.step is None or s.step == 1 for s in roi_slices):
raise ValueError(f"only slice steps of 1/None are currently supported, got {roi_slices}.")
- return ensure_tuple(roi_slices) # type: ignore
+ return ensure_tuple(roi_slices)
else:
if roi_center is not None and roi_size is not None:
roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu")
@@ -408,10 +408,8 @@ def compute_slices(
roi_end_t = torch.maximum(roi_end_t, roi_start_t)
# convert to slices (accounting for 1d)
if roi_start_t.numel() == 1:
- return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))]) # type: ignore
- return ensure_tuple( # type: ignore
- [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())]
- )
+ return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))])
+ return ensure_tuple([slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())])
def __call__( # type: ignore[override]
self, img: torch.Tensor, slices: tuple[slice, ...], lazy: bool | None = None
diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index 9d55aa013b..8ad86b72dd 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -1157,7 +1157,7 @@ def __init__(self, k: int = 1, spatial_axes: tuple[int, int] = (0, 1), lazy: boo
"""
LazyTransform.__init__(self, lazy=lazy)
self.k = (4 + (k % 4)) % 4 # 0, 1, 2, 3
- spatial_axes_: tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore
+ spatial_axes_: tuple[int, int] = ensure_tuple(spatial_axes)
if len(spatial_axes_) != 2:
raise ValueError(f"spatial_axes must be 2 numbers to define the plane to rotate, got {spatial_axes_}.")
self.spatial_axes = spatial_axes_
diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index caf02d7b00..2322f2123f 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -372,7 +372,7 @@ def __call__(self, img: NdarrayOrTensor, dtype: DtypeLike | torch.dtype = None)
TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``.
"""
- return convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype)[0] # type: ignore
+ return convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype)[0]
class ToTensor(Transform):
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index 678219991f..e282ecff24 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -521,7 +521,7 @@ def correct_crop_centers(
for c, v_s, v_e in zip(centers, valid_start, valid_end):
center_i = min(max(c, v_s), v_e - 1)
valid_centers.append(int(center_i))
- return ensure_tuple(valid_centers) # type: ignore
+ return ensure_tuple(valid_centers)
def generate_pos_neg_label_crop_centers(
@@ -579,7 +579,7 @@ def generate_pos_neg_label_crop_centers(
# shift center to range of valid centers
centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller))
- return ensure_tuple(centers) # type: ignore
+ return ensure_tuple(centers)
def generate_label_classes_crop_centers(
@@ -639,7 +639,7 @@ def generate_label_classes_crop_centers(
# shift center to range of valid centers
centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller))
- return ensure_tuple(centers) # type: ignore
+ return ensure_tuple(centers)
def create_grid(
@@ -2218,7 +2218,7 @@ def distance_transform_edt(
if not r_vals:
return None
device = img.device if isinstance(img, torch.Tensor) else None
- return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0] # type: ignore
+ return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0]
if __name__ == "__main__":
diff --git a/tests/test_inverse.py b/tests/test_inverse.py
index 3f07b43d6d..6bd14a19f1 100644
--- a/tests/test_inverse.py
+++ b/tests/test_inverse.py
@@ -310,7 +310,7 @@
TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], t[3], Compose(Compose(t[4:]))) for t in TESTS]
-TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore
+TESTS = TESTS + TESTS_COMPOSE_X2
NUM_SAMPLES = 5
N_SAMPLES_TESTS = [
diff --git a/tests/utils.py b/tests/utils.py
index cf1711292f..ee800598bb 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -832,9 +832,9 @@ def equal_state_dict(st_1, st_2):
[[2.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
)
_metatensor_creator = partial(MetaTensor, meta={"a": "b", "affine": DEFAULT_TEST_AFFINE})
-TEST_NDARRAYS_NO_META_TENSOR: tuple[Callable] = (np.array,) + TEST_TORCH_TENSORS # type: ignore
+TEST_NDARRAYS_NO_META_TENSOR: tuple[Callable] = (np.array,) + TEST_TORCH_TENSORS
TEST_NDARRAYS: tuple[Callable] = TEST_NDARRAYS_NO_META_TENSOR + (_metatensor_creator,) # type: ignore
-TEST_TORCH_AND_META_TENSORS: tuple[Callable] = TEST_TORCH_TENSORS + (_metatensor_creator,) # type: ignore
+TEST_TORCH_AND_META_TENSORS: tuple[Callable] = TEST_TORCH_TENSORS + (_metatensor_creator,)
# alias for branch tests
TEST_NDARRAYS_ALL = TEST_NDARRAYS
From 6ec781a0f07b3c3f685ba29e8a340702975f7679 Mon Sep 17 00:00:00 2001
From: elitap
Date: Wed, 22 Nov 2023 06:06:17 +0100
Subject: [PATCH 23/88] =?UTF-8?q?add=20class=20label=20option=20to=20write?=
=?UTF-8?q?=20metric=20report=20to=20improve=20readability=20=E2=80=A6=20(?=
=?UTF-8?q?#7249)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
add class label option to write metric report to improve readability,
without that option in case of many classes the resulting report is very
hard to interpret.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
---------
Signed-off-by: elitap
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/handlers/utils.py | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py
index 58a3fd36f3..0cd31b89c2 100644
--- a/monai/handlers/utils.py
+++ b/monai/handlers/utils.py
@@ -61,6 +61,7 @@ def write_metrics_reports(
summary_ops: str | Sequence[str] | None,
deli: str = ",",
output_type: str = "csv",
+ class_labels: list[str] | None = None,
) -> None:
"""
Utility function to write the metrics into files, contains 3 parts:
@@ -94,6 +95,8 @@ class mean median max 5percentile 95percentile notnans
deli: the delimiter character in the saved file, default to "," as the default output type is `csv`.
to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
output_type: expected output file type, supported types: ["csv"], default to "csv".
+ class_labels: list of class names used to name the classes in the output report, if None,
+ "class0", ..., "classn" are used, default to None.
"""
if output_type.lower() != "csv":
@@ -118,7 +121,12 @@ class mean median max 5percentile 95percentile notnans
v = v.reshape((-1, 1))
# add the average value of all classes to v
- class_labels = ["class" + str(i) for i in range(v.shape[1])] + ["mean"]
+ if class_labels is None:
+ class_labels = ["class" + str(i) for i in range(v.shape[1])]
+ else:
+ class_labels = [str(i) for i in class_labels] # ensure to have a list of str
+
+ class_labels += ["mean"]
v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1)
with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f:
From a11d8d3cfe66c0ec2fda0a4317da1445b5d3b40c Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 28 Nov 2023 00:18:51 +0800
Subject: [PATCH 24/88] Fix B026 unrecommanded star-arg unpacking after a
keyword argument (#7262)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7261
### Description
Remove star-arg unpacking before a keyword argument.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: KumoLiu
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/data/image_reader.py | 2 +-
monai/inferers/inferer.py | 2 +-
tests/test_video_datasets.py | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py
index fe199d9570..0823d11834 100644
--- a/monai/data/image_reader.py
+++ b/monai/data/image_reader.py
@@ -1300,7 +1300,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] |
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
- nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, *kwargs_))
+ nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, **kwargs_))
img_.append(nrrd_image)
return img_ if len(filenames) > 1 else img_[0]
diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index bf8c27e5c3..0b4199938d 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -584,10 +584,10 @@ def __call__(
return super().__call__(
inputs,
network,
+ *args,
device=inputs.device if gpu_stitching else torch.device("cpu"),
buffer_steps=buffer_steps if buffered_stitching else None,
buffer_dim=buffer_dim,
- *args,
**kwargs,
)
except RuntimeError as e:
diff --git a/tests/test_video_datasets.py b/tests/test_video_datasets.py
index eedbe212eb..790feb51ee 100644
--- a/tests/test_video_datasets.py
+++ b/tests/test_video_datasets.py
@@ -39,7 +39,7 @@ def get_video_source(self):
return self.video_source
def get_ds(self, *args, **kwargs) -> VideoDataset:
- return self.ds(video_source=self.get_video_source(), transform=TRANSFORMS, *args, **kwargs) # type: ignore
+ return self.ds(*args, video_source=self.get_video_source(), transform=TRANSFORMS, **kwargs) # type: ignore
@unittest.skipIf(has_cv2, "Only tested when OpenCV not installed.")
def test_no_opencv_raises(self):
From 149a73a9232fd22260e1147e6dc0e4310b199f5a Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Thu, 30 Nov 2023 10:40:00 +0800
Subject: [PATCH 25/88] Quote $PY_EXE variable to deal with Python path that
contain spaces in Bash (#7268)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #5857.
### Description
When dealing with paths that contain spaces in Bash, it's important to
properly quote the variables to ensure that spaces are handled
correctly. So, maybe we can replace all `$PY_EXE` variables to
`"$PY_EXE"` in the `runtests.sh` file.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: ytl0623
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
runtests.sh | 70 ++++++++++++++++++++++++++---------------------------
1 file changed, 35 insertions(+), 35 deletions(-)
diff --git a/runtests.sh b/runtests.sh
index cfceb6976a..0c60bc0f58 100755
--- a/runtests.sh
+++ b/runtests.sh
@@ -119,42 +119,42 @@ function print_usage {
}
# FIXME: https://github.com/Project-MONAI/MONAI/issues/4354
-protobuf_major_version=$(${PY_EXE} -m pip list | grep '^protobuf ' | tr -s ' ' | cut -d' ' -f2 | cut -d'.' -f1)
+protobuf_major_version=$("${PY_EXE}" -m pip list | grep '^protobuf ' | tr -s ' ' | cut -d' ' -f2 | cut -d'.' -f1)
if [ "$protobuf_major_version" -ge "4" ]
then
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
fi
function check_import {
- echo "Python: ${PY_EXE}"
- ${cmdPrefix}${PY_EXE} -W error -W ignore::DeprecationWarning -W ignore::ResourceWarning -c "import monai"
+ echo "Python: "${PY_EXE}""
+ ${cmdPrefix}"${PY_EXE}" -W error -W ignore::DeprecationWarning -W ignore::ResourceWarning -c "import monai"
}
function print_version {
- ${cmdPrefix}${PY_EXE} -c 'import monai; monai.config.print_config()' # project-monai/monai#6167
+ ${cmdPrefix}"${PY_EXE}" -c 'import monai; monai.config.print_config()' # project-monai/monai#6167
}
function install_deps {
echo "Pip installing MONAI development dependencies and compile MONAI cpp extensions..."
- ${cmdPrefix}${PY_EXE} -m pip install -r requirements-dev.txt
+ ${cmdPrefix}"${PY_EXE}" -m pip install -r requirements-dev.txt
}
function compile_cpp {
echo "Compiling and installing MONAI cpp extensions..."
# depends on setup.py behaviour for building
# currently setup.py uses environment variables: BUILD_MONAI and FORCE_CUDA
- ${cmdPrefix}${PY_EXE} setup.py develop --user --uninstall
+ ${cmdPrefix}"${PY_EXE}" setup.py develop --user --uninstall
if [[ "$OSTYPE" == "darwin"* ]];
then # clang for mac os
- CC=clang CXX=clang++ ${cmdPrefix}${PY_EXE} setup.py develop --user
+ CC=clang CXX=clang++ ${cmdPrefix}"${PY_EXE}" setup.py develop --user
else
- ${cmdPrefix}${PY_EXE} setup.py develop --user
+ ${cmdPrefix}"${PY_EXE}" setup.py develop --user
fi
}
function clang_format {
echo "Running clang-format..."
- ${cmdPrefix}${PY_EXE} -m tests.clang_format_utils
+ ${cmdPrefix}"${PY_EXE}" -m tests.clang_format_utils
clang_format_tool='.clang-format-bin/clang-format'
# Verify .
if ! type -p "$clang_format_tool" >/dev/null; then
@@ -167,19 +167,19 @@ function clang_format {
}
function is_pip_installed() {
- return $(${PY_EXE} -c "import sys, pkgutil; sys.exit(0 if pkgutil.find_loader(sys.argv[1]) else 1)" $1)
+ return $("${PY_EXE}" -c "import sys, pkgutil; sys.exit(0 if pkgutil.find_loader(sys.argv[1]) else 1)" $1)
}
function clean_py {
if is_pip_installed coverage
then
# remove coverage history
- ${cmdPrefix}${PY_EXE} -m coverage erase
+ ${cmdPrefix}"${PY_EXE}" -m coverage erase
fi
# uninstall the development package
echo "Uninstalling MONAI development files..."
- ${cmdPrefix}${PY_EXE} setup.py develop --user --uninstall
+ ${cmdPrefix}"${PY_EXE}" setup.py develop --user --uninstall
# remove temporary files (in the directory of this script)
TO_CLEAN="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
@@ -201,7 +201,7 @@ function clean_py {
}
function torch_validate {
- ${cmdPrefix}${PY_EXE} -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
+ ${cmdPrefix}"${PY_EXE}" -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
}
function print_error_msg() {
@@ -219,7 +219,7 @@ function print_style_fail_msg() {
}
function list_unittests() {
- ${PY_EXE} - << END
+ "${PY_EXE}" - << END
import unittest
def print_suite(suite):
if hasattr(suite, "__iter__"):
@@ -448,7 +448,7 @@ then
then
install_deps
fi
- ${cmdPrefix}${PY_EXE} -m pre_commit run --all-files
+ ${cmdPrefix}"${PY_EXE}" -m pre_commit run --all-files
pre_commit_status=$?
if [ ${pre_commit_status} -ne 0 ]
@@ -477,13 +477,13 @@ then
then
install_deps
fi
- ${cmdPrefix}${PY_EXE} -m isort --version
+ ${cmdPrefix}"${PY_EXE}" -m isort --version
if [ $doIsortFix = true ]
then
- ${cmdPrefix}${PY_EXE} -m isort "$homedir"
+ ${cmdPrefix}"${PY_EXE}" -m isort "$homedir"
else
- ${cmdPrefix}${PY_EXE} -m isort --check "$homedir"
+ ${cmdPrefix}"${PY_EXE}" -m isort --check "$homedir"
fi
isort_status=$?
@@ -513,13 +513,13 @@ then
then
install_deps
fi
- ${cmdPrefix}${PY_EXE} -m black --version
+ ${cmdPrefix}"${PY_EXE}" -m black --version
if [ $doBlackFix = true ]
then
- ${cmdPrefix}${PY_EXE} -m black --skip-magic-trailing-comma "$homedir"
+ ${cmdPrefix}"${PY_EXE}" -m black --skip-magic-trailing-comma "$homedir"
else
- ${cmdPrefix}${PY_EXE} -m black --skip-magic-trailing-comma --check "$homedir"
+ ${cmdPrefix}"${PY_EXE}" -m black --skip-magic-trailing-comma --check "$homedir"
fi
black_status=$?
@@ -544,9 +544,9 @@ then
then
install_deps
fi
- ${cmdPrefix}${PY_EXE} -m flake8 --version
+ ${cmdPrefix}"${PY_EXE}" -m flake8 --version
- ${cmdPrefix}${PY_EXE} -m flake8 "$homedir" --count --statistics
+ ${cmdPrefix}"${PY_EXE}" -m flake8 "$homedir" --count --statistics
flake8_status=$?
if [ ${flake8_status} -ne 0 ]
@@ -568,12 +568,12 @@ then
if ! is_pip_installed pylint
then
echo "Pip installing pylint ..."
- ${cmdPrefix}${PY_EXE} -m pip install "pylint>2.16,!=3.0.0"
+ ${cmdPrefix}"${PY_EXE}" -m pip install "pylint>2.16,!=3.0.0"
fi
- ${cmdPrefix}${PY_EXE} -m pylint --version
+ ${cmdPrefix}"${PY_EXE}" -m pylint --version
ignore_codes="C,R,W,E1101,E1102,E0601,E1130,E1123,E0102,E1120,E1137,E1136"
- ${cmdPrefix}${PY_EXE} -m pylint monai tests --disable=$ignore_codes -j $NUM_PARALLEL
+ ${cmdPrefix}"${PY_EXE}" -m pylint monai tests --disable=$ignore_codes -j $NUM_PARALLEL
pylint_status=$?
if [ ${pylint_status} -ne 0 ]
@@ -632,14 +632,14 @@ then
then
install_deps
fi
- pytype_ver=$(${cmdPrefix}${PY_EXE} -m pytype --version)
+ pytype_ver=$(${cmdPrefix}"${PY_EXE}" -m pytype --version)
if [[ "$OSTYPE" == "darwin"* && "$pytype_ver" == "2021."* ]]; then
echo "${red}pytype not working on macOS 2021 (https://github.com/Project-MONAI/MONAI/issues/2391). Please upgrade to 2022*.${noColor}"
exit 1
else
- ${cmdPrefix}${PY_EXE} -m pytype --version
+ ${cmdPrefix}"${PY_EXE}" -m pytype --version
- ${cmdPrefix}${PY_EXE} -m pytype -j ${NUM_PARALLEL} --python-version="$(${PY_EXE} -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")" "$homedir"
+ ${cmdPrefix}"${PY_EXE}" -m pytype -j ${NUM_PARALLEL} --python-version="$(${PY_EXE} -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")" "$homedir"
pytype_status=$?
if [ ${pytype_status} -ne 0 ]
@@ -664,8 +664,8 @@ then
then
install_deps
fi
- ${cmdPrefix}${PY_EXE} -m mypy --version
- ${cmdPrefix}${PY_EXE} -m mypy "$homedir"
+ ${cmdPrefix}"${PY_EXE}" -m mypy --version
+ ${cmdPrefix}"${PY_EXE}" -m mypy "$homedir"
mypy_status=$?
if [ ${mypy_status} -ne 0 ]
@@ -695,7 +695,7 @@ if [ $doMinTests = true ]
then
echo "${separator}${blue}min${noColor}"
doCoverage=false
- ${cmdPrefix}${PY_EXE} -m tests.min_tests
+ ${cmdPrefix}"${PY_EXE}" -m tests.min_tests
fi
# set coverage command
@@ -707,7 +707,7 @@ then
then
install_deps
fi
- cmd="${PY_EXE} -m coverage run --append"
+ cmd=""${PY_EXE}" -m coverage run --append"
fi
# # download test data if needed
@@ -763,6 +763,6 @@ then
then
install_deps
fi
- ${cmdPrefix}${PY_EXE} -m coverage combine --append .coverage/
- ${cmdPrefix}${PY_EXE} -m coverage report --ignore-errors
+ ${cmdPrefix}"${PY_EXE}" -m coverage combine --append .coverage/
+ ${cmdPrefix}"${PY_EXE}" -m coverage report --ignore-errors
fi
From 308e9e27ade95e235023c02ca2ecc334edd9cda6 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Thu, 30 Nov 2023 12:24:11 +0800
Subject: [PATCH 26/88] add SoftclDiceLoss and SoftDiceclDiceLoss loss function
in documentation (#7271)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7193
### Description
add SoftclDiceLoss and SoftDiceclDiceLoss loss function in
documentation(https://github.com/Project-MONAI/MONAI/blob/dev/docs/source/losses.rst?plain=1)
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: ytl0623
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
docs/source/losses.rst | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index 5d488afbb3..568c7dfc77 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -78,6 +78,16 @@ Segmentation Losses
.. autoclass:: HausdorffDTLoss
:members:
+`SoftclDiceLoss`
+~~~~~~~~~~~~~~~~
+.. autoclass:: SoftclDiceLoss
+ :members:
+
+`SoftDiceclDiceLoss`
+~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: SoftDiceclDiceLoss
+ :members:
+
Registration Losses
-------------------
From be4a873b6a126e46f0ac8b914e68cdc8e3238029 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Thu, 30 Nov 2023 20:10:45 +0800
Subject: [PATCH 27/88] Skip Old Pytorch Versions for `SwinUNETR` (#7266)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7265.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: KumoLiu
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
tests/test_swin_unetr.py | 46 ++++++++++++++++++++++++----------------
1 file changed, 28 insertions(+), 18 deletions(-)
diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py
index 9f6b1e7c0a..e34e5a3c8e 100644
--- a/tests/test_swin_unetr.py
+++ b/tests/test_swin_unetr.py
@@ -24,13 +24,21 @@
from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR, filter_swinunetr
from monai.networks.utils import copy_model_state
from monai.utils import optional_import
-from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, testing_data_config
+from tests.utils import (
+ assert_allclose,
+ pytorch_after,
+ skip_if_downloading_fails,
+ skip_if_no_cuda,
+ skip_if_quick,
+ testing_data_config,
+)
einops, has_einops = optional_import("einops")
TEST_CASE_SWIN_UNETR = []
case_idx = 0
test_merging_mode = ["mergingv2", "merging", PatchMerging, PatchMergingV2]
+checkpoint_vals = [True, False] if pytorch_after(1, 11) else [False]
for attn_drop_rate in [0.4]:
for in_channels in [1]:
for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]:
@@ -38,23 +46,25 @@
for img_size in ((64, 32, 192), (96, 32)):
for feature_size in [12]:
for norm_name in ["instance"]:
- test_case = [
- {
- "spatial_dims": len(img_size),
- "in_channels": in_channels,
- "out_channels": out_channels,
- "img_size": img_size,
- "feature_size": feature_size,
- "depths": depth,
- "norm_name": norm_name,
- "attn_drop_rate": attn_drop_rate,
- "downsample": test_merging_mode[case_idx % 4],
- },
- (2, in_channels, *img_size),
- (2, out_channels, *img_size),
- ]
- case_idx += 1
- TEST_CASE_SWIN_UNETR.append(test_case)
+ for use_checkpoint in checkpoint_vals:
+ test_case = [
+ {
+ "spatial_dims": len(img_size),
+ "in_channels": in_channels,
+ "out_channels": out_channels,
+ "img_size": img_size,
+ "feature_size": feature_size,
+ "depths": depth,
+ "norm_name": norm_name,
+ "attn_drop_rate": attn_drop_rate,
+ "downsample": test_merging_mode[case_idx % 4],
+ "use_checkpoint": use_checkpoint,
+ },
+ (2, in_channels, *img_size),
+ (2, out_channels, *img_size),
+ ]
+ case_idx += 1
+ TEST_CASE_SWIN_UNETR.append(test_case)
TEST_CASE_FILTER = [
[
From a6e1b71d37103a6273a4f832e354f26c0bd980b6 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Fri, 1 Dec 2023 16:30:30 +0800
Subject: [PATCH 28/88] Bump conda-incubator/setup-miniconda from 2 to 3
(#7274)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Bumps
[conda-incubator/setup-miniconda](https://github.com/conda-incubator/setup-miniconda)
from 2 to 3.
Release notes
Sourced from conda-incubator/setup-miniconda's
releases.
Version 3.0.0
Features
- #308
Update to node20
- #291
Add conda-solver option (defaults to libmamba)
Fixes
- #299
Fix condaBasePath when useBundled is false, and there's no pre-existing
conda
Documentation
- #309
Switch to main branch based development
- #313
Specify team conda-incubator/setup-miniconda as codeowners
- #318
README: update actions in examples, add security section, similar
actions
Tasks and Maintenance
- #307
Run dependabot against main branch and also update node packages
- #311
Bump actions/checkout from 2 to 4
- #310
Bump actions/cache from 1 to 3
- #314
Strip/update dependencies
- #315
Split lint into check and build, switch from
npm install to
npm ci
- #317
Bump normalize-url from 4.5.1 to 8.0.0
- #316
Faster workflow response / saving resources via timeout/concurrency
policy
#308:
conda-incubator/setup-miniconda#308
#291:
conda-incubator/setup-miniconda#291
#299:
conda-incubator/setup-miniconda#299
#309:
conda-incubator/setup-miniconda#309
#313:
conda-incubator/setup-miniconda#313
#318:
conda-incubator/setup-miniconda#318
#307:
conda-incubator/setup-miniconda#307
#311:
conda-incubator/setup-miniconda#311
#310:
conda-incubator/setup-miniconda#310
#314:
conda-incubator/setup-miniconda#314
#315:
conda-incubator/setup-miniconda#315
#317:
conda-incubator/setup-miniconda#317
#316:
conda-incubator/setup-miniconda#316
New Contributors
Full Changelog: https://github.com/conda-incubator/setup-miniconda/compare/v2...v3.0.0
Version 2.3.0
Documentation
- #263
Update links to GitHub shell docs
... (truncated)
Changelog
Sourced from conda-incubator/setup-miniconda's
changelog.
v3.0.1
(2023-11-29)
Fixes
- #325
Fix environment activation on windows (a v3 regression) due to
hard-coded install PATH
#325:
conda-incubator/setup-miniconda#325
v3.0.0
(2023-11-27)
Features
- #308
Update to node20
- #291
Add conda-solver option (defaults to libmamba)
Fixes
- #299
Fix condaBasePath when useBundled is false, and there's no pre-existing
conda
Documentation
- #309
Switch to main branch based development
- #313
Specify team conda-incubator/setup-miniconda as codeowners
- #318
README: update actions in examples, add security section, similar
actions
Tasks and Maintenance
- #307
Run dependabot against main branch and also update node packages
- #311
Bump actions/checkout from 2 to 4
- #310
Bump actions/cache from 1 to 3
- #314
Strip/update dependencies
- #315
Split lint into check and build, switch from
npm install to
npm ci
- #317
Bump normalize-url from 4.5.1 to 8.0.0
- #316
Faster workflow response / saving resources via timeout/concurrency
policy
#308:
conda-incubator/setup-miniconda#308
#291:
conda-incubator/setup-miniconda#291
#299:
conda-incubator/setup-miniconda#299
#309:
conda-incubator/setup-miniconda#309
#313:
conda-incubator/setup-miniconda#313
#318:
conda-incubator/setup-miniconda#318
#307:
conda-incubator/setup-miniconda#307
#311:
conda-incubator/setup-miniconda#311
#310:
conda-incubator/setup-miniconda#310
... (truncated)
Commits
11b5629
Prepare 3.0.1 (#326)
8706aa7
Fix env activation on win (a v3 regression) due to hard-coded install
PATH (#...
c585a97
Bump conda-incubator/setup-miniconda from 2.3.0 to 3.0.0 (#321)
2defc80
Prepare release (#320)
0d5a56b
Bump actions/checkout from 2 to 4 (#319)
45fd3f9
Merge pull request #316
from dbast/timeout
d1e04fc
Merge pull request #299
from isuruf/condaBasePath
fab0073
Merge pull request #318
from dbast/readme
fa6bdf9
Update with npm run build
d42f8b8
Fix condaBasePath when useBundled is false, and there's no pre-existing
conda
- Additional commits viewable in compare
view
[](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/conda.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml
index dc07f833be..a387c77ebd 100644
--- a/.github/workflows/conda.yml
+++ b/.github/workflows/conda.yml
@@ -32,7 +32,7 @@ jobs:
maximum-size: 16GB
disk-root: "D:"
- uses: actions/checkout@v4
- - uses: conda-incubator/setup-miniconda@v2
+ - uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
python-version: ${{ matrix.python-version }}
From dae54786304400ec372b38d367bd7fb98aaef8f8 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Mon, 4 Dec 2023 11:50:01 +0800
Subject: [PATCH 29/88] wholeBody_ct_segmentation failed to be download (#7280)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes model-zoo#537
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: KumoLiu
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/bundle/scripts.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py
index 20a491e493..2565a3cf64 100644
--- a/monai/bundle/scripts.py
+++ b/monai/bundle/scripts.py
@@ -221,7 +221,7 @@ def _download_from_ngc(
def _get_latest_bundle_version_monaihosting(name):
url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
- full_url = f"{url}/{name}"
+ full_url = f"{url}/{name.lower()}"
requests_get, has_requests = optional_import("requests", name="get")
if has_requests:
resp = requests_get(full_url)
From 98021094d151d3f8b29eb2cdaa500c37448ee474 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 5 Dec 2023 13:20:01 +0800
Subject: [PATCH 30/88] update the Python version requirements for transformers
(#7275)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Part of #7250.
### Description
Fix the Python version for transformers smaller than 3.10.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: KumoLiu
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
docs/requirements.txt | 2 +-
requirements-dev.txt | 2 +-
setup.cfg | 4 ++--
tests/test_transchex.py | 3 ++-
4 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/docs/requirements.txt b/docs/requirements.txt
index a9bbc384f8..e5bedf8552 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -21,7 +21,7 @@ sphinxcontrib-serializinghtml
sphinx-autodoc-typehints==1.11.1
pandas
einops
-transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157
+transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157
mlflow>=1.28.0
clearml>=1.10.0rc0
tensorboardX
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 6332d5b0a5..cacbefe234 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -33,7 +33,7 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin"
pandas
requests
einops
-transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157
+transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157
mlflow>=1.28.0
clearml>=1.10.0rc0
matplotlib!=3.5.0
diff --git a/setup.cfg b/setup.cfg
index 123da68dfa..0370d0062d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -65,7 +65,7 @@ all =
imagecodecs
pandas
einops
- transformers<4.22
+ transformers<4.22; python_version <= '3.10'
mlflow>=1.28.0
clearml>=1.10.0rc0
matplotlib
@@ -123,7 +123,7 @@ pandas =
einops =
einops
transformers =
- transformers<4.22
+ transformers<4.22; python_version <= '3.10'
mlflow =
mlflow
matplotlib =
diff --git a/tests/test_transchex.py b/tests/test_transchex.py
index 9ad847cdaa..8fb1f56715 100644
--- a/tests/test_transchex.py
+++ b/tests/test_transchex.py
@@ -18,7 +18,7 @@
from monai.networks import eval_mode
from monai.networks.nets.transchex import Transchex
-from tests.utils import skip_if_quick
+from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_quick
TEST_CASE_TRANSCHEX = []
for drop_out in [0.4]:
@@ -46,6 +46,7 @@
@skip_if_quick
+@SkipIfAtLeastPyTorchVersion((1, 10))
class TestTranschex(unittest.TestCase):
@parameterized.expand(TEST_CASE_TRANSCHEX)
def test_shape(self, input_param, expected_shape):
From 5fb9f2b6ba08b353d9a2e6083b68663e993ee8b7 Mon Sep 17 00:00:00 2001
From: Kaibo Tang
Date: Tue, 5 Dec 2023 03:46:24 -0500
Subject: [PATCH 31/88] 7263 add diffusion loss (#7272)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7263.
### Description
Add diffusion loss. I also made a [demo
notebook](https://github.com/kvttt/deep-atlas/blob/main/diffusion_loss_scale_test.ipynb)
to provide some explanations and analyses of diffusion loss.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: kaibo
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
docs/source/losses.rst | 5 ++
monai/losses/__init__.py | 2 +-
monai/losses/deform.py | 82 +++++++++++++++++++++++++
tests/test_diffusion_loss.py | 116 +++++++++++++++++++++++++++++++++++
4 files changed, 204 insertions(+), 1 deletion(-)
create mode 100644 tests/test_diffusion_loss.py
diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index 568c7dfc77..e929e9d605 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -96,6 +96,11 @@ Registration Losses
.. autoclass:: BendingEnergyLoss
:members:
+`DiffusionLoss`
+~~~~~~~~~~~~~~~
+.. autoclass:: DiffusionLoss
+ :members:
+
`LocalNormalizedCrossCorrelationLoss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNormalizedCrossCorrelationLoss
diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py
index d734a9d44d..92898c81ca 100644
--- a/monai/losses/__init__.py
+++ b/monai/losses/__init__.py
@@ -14,7 +14,7 @@
from .adversarial_loss import PatchAdversarialLoss
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
from .contrastive import ContrastiveLoss
-from .deform import BendingEnergyLoss
+from .deform import BendingEnergyLoss, DiffusionLoss
from .dice import (
Dice,
DiceCELoss,
diff --git a/monai/losses/deform.py b/monai/losses/deform.py
index dd03a8eb3d..129abeedd2 100644
--- a/monai/losses/deform.py
+++ b/monai/losses/deform.py
@@ -116,3 +116,85 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor:
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
return energy
+
+
+class DiffusionLoss(_Loss):
+ """
+ Calculate the diffusion based on first-order differentiation of pred using central finite difference.
+ For the original paper, please refer to
+ VoxelMorph: A Learning Framework for Deformable Medical Image Registration,
+ Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca
+ IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231.
+
+ Adapted from:
+ VoxelMorph (https://github.com/voxelmorph/voxelmorph)
+ """
+
+ def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None:
+ """
+ Args:
+ normalize:
+ Whether to divide out spatial sizes in order to make the computation roughly
+ invariant to image scale (i.e. vector field sampling resolution). Defaults to False.
+ reduction: {``"none"``, ``"mean"``, ``"sum"``}
+ Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+
+ - ``"none"``: no reduction will be applied.
+ - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
+ - ``"sum"``: the output will be summed.
+ """
+ super().__init__(reduction=LossReduction(reduction).value)
+ self.normalize = normalize
+
+ def forward(self, pred: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ pred:
+ Predicted dense displacement field (DDF) with shape BCH[WD],
+ where C is the number of spatial dimensions.
+ Note that diffusion loss can only be calculated
+ when the sizes of the DDF along all spatial dimensions are greater than 2.
+
+ Raises:
+ ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
+ ValueError: When ``pred`` is not 3-d, 4-d or 5-d.
+ ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2.
+ ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions.
+
+ """
+ if pred.ndim not in [3, 4, 5]:
+ raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}")
+ for i in range(pred.ndim - 2):
+ if pred.shape[-i - 1] <= 2:
+ raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}")
+ if pred.shape[1] != pred.ndim - 2:
+ raise ValueError(
+ f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, "
+ f"does not match number of spatial dimensions, {pred.ndim - 2}"
+ )
+
+ # first order gradient
+ first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)]
+
+ # spatial dimensions in a shape suited for broadcasting below
+ if self.normalize:
+ spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,))
+
+ diffusion = torch.tensor(0)
+ for dim_1, g in enumerate(first_order_gradient):
+ dim_1 += 2
+ if self.normalize:
+ # We divide the partial derivative for each vector component at each voxel by the spatial size
+ # corresponding to that component relative to the spatial size of the vector component with respect
+ # to which the partial derivative is taken.
+ g *= pred.shape[dim_1] / spatial_dims
+ diffusion = diffusion + g**2
+
+ if self.reduction == LossReduction.MEAN.value:
+ diffusion = torch.mean(diffusion) # the batch and channel average
+ elif self.reduction == LossReduction.SUM.value:
+ diffusion = torch.sum(diffusion) # sum over the batch and channel dims
+ elif self.reduction != LossReduction.NONE.value:
+ raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
+
+ return diffusion
diff --git a/tests/test_diffusion_loss.py b/tests/test_diffusion_loss.py
new file mode 100644
index 0000000000..05dfab95fb
--- /dev/null
+++ b/tests/test_diffusion_loss.py
@@ -0,0 +1,116 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.losses.deform import DiffusionLoss
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+TEST_CASES = [
+ # all first partials are zero, so the diffusion loss is also zero
+ [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0],
+ # all first partials are one, so the diffusion loss is also one
+ [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0],
+ # before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 = 18.67
+ [
+ {"normalize": False},
+ {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
+ 56.0 / 3.0,
+ ],
+ # same as the previous case
+ [
+ {"normalize": False},
+ {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
+ 56.0 / 3.0,
+ ],
+ # same as the previous case
+ [{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],
+ # we have shown in the demo notebook that
+ # diffusion loss is scale-invariant when the all axes have the same resolution
+ [
+ {"normalize": True},
+ {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
+ 56.0 / 3.0,
+ ],
+ [
+ {"normalize": True},
+ {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
+ 56.0 / 3.0,
+ ],
+ [{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],
+ # for the following case, consider the following 2D matrix:
+ # tensor([[[[0, 1, 2],
+ # [1, 2, 3],
+ # [2, 3, 4],
+ # [3, 4, 5],
+ # [4, 5, 6]],
+ # [[0, 1, 2],
+ # [1, 2, 3],
+ # [2, 3, 4],
+ # [3, 4, 5],
+ # [4, 5, 6]]]])
+ # the first partials wrt x are all ones, and so are the first partials wrt y
+ # the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2
+ [{"normalize": False}, {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0],
+ # consider the same matrix, this time with normalization applied, using the same notation as in the demo notebook,
+ # the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y
+ # the diffusion loss is then (1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2 = (1 + 9/25 + 25/9 + 1) / 2 = 2.5689
+ [
+ {"normalize": True},
+ {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)},
+ (1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0,
+ ],
+]
+
+
+class TestDiffusionLoss(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_shape(self, input_param, input_data, expected_val):
+ result = DiffusionLoss(**input_param).forward(**input_data)
+ np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)
+
+ def test_ill_shape(self):
+ loss = DiffusionLoss()
+ # not in 3-d, 4-d, 5-d
+ with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
+ loss.forward(torch.ones((1, 3), device=device))
+ with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
+ loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))
+ with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
+ loss.forward(torch.ones((1, 3, 2, 5, 5), device=device))
+ with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
+ loss.forward(torch.ones((1, 3, 5, 2, 5)))
+ with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
+ loss.forward(torch.ones((1, 3, 5, 5, 2)))
+
+ # number of vector components unequal to number of spatial dims
+ with self.assertRaisesRegex(ValueError, "Number of vector components"):
+ loss.forward(torch.ones((1, 2, 5, 5, 5)))
+ with self.assertRaisesRegex(ValueError, "Number of vector components"):
+ loss.forward(torch.ones((1, 2, 5, 5, 5)))
+
+ def test_ill_opts(self):
+ pred = torch.rand(1, 3, 5, 5, 5).to(device=device)
+ with self.assertRaisesRegex(ValueError, ""):
+ DiffusionLoss(reduction="unknown")(pred)
+ with self.assertRaisesRegex(ValueError, ""):
+ DiffusionLoss(reduction=None)(pred)
+
+
+if __name__ == "__main__":
+ unittest.main()
From 6a45df237730e369c10e184a3dc747c57c9d098a Mon Sep 17 00:00:00 2001
From: Yufan He <59374597+heyufan1995@users.noreply.github.com>
Date: Thu, 7 Dec 2023 21:36:21 -0500
Subject: [PATCH 32/88] Fix swinunetrv2 2D bug (#7302)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes # .
### Description
A few sentences describing the changes proposed in this pull request.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: heyufan1995
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/networks/nets/swin_unetr.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py
index 10c4ce3d8e..6f96dfd291 100644
--- a/monai/networks/nets/swin_unetr.py
+++ b/monai/networks/nets/swin_unetr.py
@@ -1024,7 +1024,7 @@ def __init__(
self.layers4.append(layer)
if self.use_v2:
layerc = UnetrBasicBlock(
- spatial_dims=3,
+ spatial_dims=spatial_dims,
in_channels=embed_dim * 2**i_layer,
out_channels=embed_dim * 2**i_layer,
kernel_size=3,
From fc82d506b5090601bc336d71ee23377c35e6481a Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 12 Dec 2023 11:07:04 +0800
Subject: [PATCH 33/88] Fix `RuntimeError` in `DataAnalyzer` (#7310)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7309
### Description
`DataAnalyzer` only catch error when data is on GPU, add catching error
when data is on CPU.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/apps/auto3dseg/data_analyzer.py | 26 ++++++++++++++++----------
monai/auto3dseg/analyzer.py | 2 +-
2 files changed, 17 insertions(+), 11 deletions(-)
diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py
index 9280fb5be5..15e56abfea 100644
--- a/monai/apps/auto3dseg/data_analyzer.py
+++ b/monai/apps/auto3dseg/data_analyzer.py
@@ -28,7 +28,7 @@
from monai.data import DataLoader, Dataset, partition_dataset
from monai.data.utils import no_collation
from monai.transforms import Compose, EnsureTyped, LoadImaged, Orientationd
-from monai.utils import StrEnum, min_version, optional_import
+from monai.utils import ImageMetaKey, StrEnum, min_version, optional_import
from monai.utils.enums import DataStatsKeys, ImageStatsKeys
@@ -343,19 +343,25 @@ def _get_all_case_stats(
d = summarizer(batch_data)
except BaseException as err:
if "image_meta_dict" in batch_data.keys():
- filename = batch_data["image_meta_dict"]["filename_or_obj"]
+ filename = batch_data["image_meta_dict"][ImageMetaKey.FILENAME_OR_OBJ]
else:
- filename = batch_data[self.image_key].meta["filename_or_obj"]
+ filename = batch_data[self.image_key].meta[ImageMetaKey.FILENAME_OR_OBJ]
logger.info(f"Unable to process data {filename} on {device}. {err}")
if self.device.type == "cuda":
logger.info("DataAnalyzer `device` set to GPU execution hit an exception. Falling back to `cpu`.")
- batch_data[self.image_key] = batch_data[self.image_key].to("cpu")
- if self.label_key is not None:
- label = batch_data[self.label_key]
- if not _label_argmax:
- label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0]
- batch_data[self.label_key] = label.to("cpu")
- d = summarizer(batch_data)
+ try:
+ batch_data[self.image_key] = batch_data[self.image_key].to("cpu")
+ if self.label_key is not None:
+ label = batch_data[self.label_key]
+ if not _label_argmax:
+ label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0]
+ batch_data[self.label_key] = label.to("cpu")
+ d = summarizer(batch_data)
+ except BaseException as err:
+ logger.info(f"Unable to process data {filename} on {device}. {err}")
+ continue
+ else:
+ continue
stats_by_cases = {
DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH],
diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py
index 654999d439..d5cfb21dab 100644
--- a/monai/auto3dseg/analyzer.py
+++ b/monai/auto3dseg/analyzer.py
@@ -460,7 +460,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
torch.set_grad_enabled(False)
ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore
- ndas_label: MetaTensor = d[self.label_key] # (H,W,D)
+ ndas_label: MetaTensor = d[self.label_key].astype(torch.int8) # (H,W,D)
if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
From 210b23ab87c70c02c84207531b641922e5e45dae Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Thu, 14 Dec 2023 10:14:19 +0800
Subject: [PATCH 34/88] Support specified filenames in `Saveimage` (#7318)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7317
### Description
Add support specified filename for users to save like nibabel.save.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/io/array.py | 17 ++++++++++++++---
tests/test_save_image.py | 16 ++++++++++++++++
2 files changed, 30 insertions(+), 3 deletions(-)
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index cd7e4ef090..7222a26fc3 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -414,6 +414,9 @@ def __init__(
self.fname_formatter = output_name_formatter
self.output_ext = output_ext.lower() or output_format.lower()
+ self.output_ext = (
+ f".{self.output_ext}" if self.output_ext and not self.output_ext.startswith(".") else self.output_ext
+ )
if isinstance(writer, str):
writer_, has_built_in = optional_import("monai.data", name=f"{writer}") # search built-in
if not has_built_in:
@@ -458,15 +461,23 @@ def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, writ
self.write_kwargs.update(write_kwargs)
return self
- def __call__(self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None):
+ def __call__(
+ self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None, filename: str | PathLike | None = None
+ ):
"""
Args:
img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`.
meta_data: key-value pairs of metadata corresponding to the data.
+ filename: str or file-like object which to save img.
+ If specified, will ignore `self.output_name_formatter` and `self.folder_layout`.
"""
meta_data = img.meta if isinstance(img, MetaTensor) else meta_data
- kw = self.fname_formatter(meta_data, self)
- filename = self.folder_layout.filename(**kw)
+ if filename is not None:
+ filename = f"{filename}{self.output_ext}"
+ else:
+ kw = self.fname_formatter(meta_data, self)
+ filename = self.folder_layout.filename(**kw)
+
if meta_data:
meta_spatial_shape = ensure_tuple(meta_data.get("spatial_shape", ()))
if len(meta_spatial_shape) >= len(img.shape):
diff --git a/tests/test_save_image.py b/tests/test_save_image.py
index ba94ab5087..d88db201ce 100644
--- a/tests/test_save_image.py
+++ b/tests/test_save_image.py
@@ -37,6 +37,8 @@
False,
]
+TEST_CASE_5 = [torch.randint(0, 255, (3, 2, 4, 5), dtype=torch.uint8), ".dcm", False]
+
@unittest.skipUnless(has_itk, "itk not installed")
class TestSaveImage(unittest.TestCase):
@@ -58,6 +60,20 @@ def test_saved_content(self, test_data, meta_data, output_ext, resample):
filepath = "testfile0" if meta_data is not None else "0"
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath + "_trans" + output_ext)))
+ @parameterized.expand([TEST_CASE_5])
+ def test_saved_content_with_filename(self, test_data, output_ext, resample):
+ with tempfile.TemporaryDirectory() as tempdir:
+ trans = SaveImage(
+ output_dir=tempdir,
+ output_ext=output_ext,
+ resample=resample,
+ separate_folder=False, # test saving into the same folder
+ )
+ filename = str(os.path.join(tempdir, "test"))
+ trans(test_data, filename=filename)
+
+ self.assertTrue(os.path.exists(filename + output_ext))
+
if __name__ == "__main__":
unittest.main()
From 2fb60c156d5287139ab066195c2f22f1077d0dcd Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Fri, 15 Dec 2023 11:32:18 +0800
Subject: [PATCH 35/88] Fix typo (#7321)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fix typo.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/metrics/hausdorff_distance.py | 2 +-
monai/metrics/surface_dice.py | 2 +-
monai/metrics/surface_distance.py | 2 +-
monai/metrics/utils.py | 6 +++---
4 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py
index d9bbf17db3..d727eb0567 100644
--- a/monai/metrics/hausdorff_distance.py
+++ b/monai/metrics/hausdorff_distance.py
@@ -190,7 +190,7 @@ def compute_hausdorff_distance(
y[b, c],
distance_metric=distance_metric,
spacing=spacing_list[b],
- symetric=not directed,
+ symmetric=not directed,
class_index=c,
)
percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances]
diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py
index 635eb1bc24..b20b47a1a5 100644
--- a/monai/metrics/surface_dice.py
+++ b/monai/metrics/surface_dice.py
@@ -253,7 +253,7 @@ def compute_surface_dice(
distance_metric=distance_metric,
spacing=spacing_list[b],
use_subvoxels=use_subvoxels,
- symetric=True,
+ symmetric=True,
class_index=c,
)
boundary_correct: int | torch.Tensor | float
diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py
index 7ce632c588..3cb336d6a0 100644
--- a/monai/metrics/surface_distance.py
+++ b/monai/metrics/surface_distance.py
@@ -177,7 +177,7 @@ def compute_average_surface_distance(
y[b, c],
distance_metric=distance_metric,
spacing=spacing_list[b],
- symetric=symmetric,
+ symmetric=symmetric,
class_index=c,
)
surface_distance = torch.cat(distances)
diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py
index 62e6520b96..d4b8f6e9b6 100644
--- a/monai/metrics/utils.py
+++ b/monai/metrics/utils.py
@@ -295,7 +295,7 @@ def get_edge_surface_distance(
distance_metric: str = "euclidean",
spacing: int | float | np.ndarray | Sequence[int | float] | None = None,
use_subvoxels: bool = False,
- symetric: bool = False,
+ symmetric: bool = False,
class_index: int = -1,
) -> tuple[
tuple[torch.Tensor, torch.Tensor],
@@ -314,7 +314,7 @@ def get_edge_surface_distance(
See :py:func:`monai.metrics.utils.get_surface_distance`.
use_subvoxels: whether to use subvoxel resolution (using the spacing).
This will return the areas of the edges.
- symetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`.
+ symmetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`.
class_index: The class-index used for context when warning about empty ground truth or prediction.
Returns:
@@ -338,7 +338,7 @@ def get_edge_surface_distance(
" this may result in nan/inf distance."
)
distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]
- if symetric:
+ if symmetric:
distances = (
get_surface_distance(edges_pred, edges_gt, distance_metric, spacing),
get_surface_distance(edges_gt, edges_pred, distance_metric, spacing),
From 275d51f69dd55eceb0f4f10e666ba6af083f95ac Mon Sep 17 00:00:00 2001
From: binliunls <107988372+binliunls@users.noreply.github.com>
Date: Fri, 15 Dec 2023 22:00:24 +0800
Subject: [PATCH 36/88] fix optimizer pararmeter issue (#7322)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes # .
### Description
A few sentences describing the changes proposed in this pull request.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: binliu
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/handlers/mlflow_handler.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py
index a2bd345dc6..df209c1c8b 100644
--- a/monai/handlers/mlflow_handler.py
+++ b/monai/handlers/mlflow_handler.py
@@ -401,7 +401,7 @@ def _default_iteration_log(self, engine: Engine) -> None:
cur_optimizer = engine.optimizer
for param_name in self.optimizer_param_names:
params = {
- f"{param_name} group_{i}": float(param_group[param_name])
+ f"{param_name}_group_{i}": float(param_group[param_name])
for i, param_group in enumerate(cur_optimizer.param_groups)
}
self._log_metrics(params, step=engine.state.iteration)
From 469db7ae45ca169666b3a729f96dd78d7fd929ca Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Mon, 18 Dec 2023 12:00:43 +0800
Subject: [PATCH 37/88] Fix `lazy` ignored in `SpatialPadd` (#7316)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7314 #7315.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Ben Murray
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/croppad/dictionary.py | 9 +++------
tests/padders.py | 3 +++
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py
index 56d214c51d..be9441dc4a 100644
--- a/monai/transforms/croppad/dictionary.py
+++ b/monai/transforms/croppad/dictionary.py
@@ -221,9 +221,8 @@ def __init__(
note that `np.pad` treats channel dimension as the first dimension.
"""
- LazyTransform.__init__(self, lazy)
padder = SpatialPad(spatial_size, method, lazy=lazy, **kwargs)
- Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys)
+ Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy)
class BorderPadd(Padd):
@@ -274,9 +273,8 @@ def __init__(
note that `np.pad` treats channel dimension as the first dimension.
"""
- LazyTransform.__init__(self, lazy)
padder = BorderPad(spatial_border=spatial_border, lazy=lazy, **kwargs)
- Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys)
+ Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy)
class DivisiblePadd(Padd):
@@ -324,9 +322,8 @@ def __init__(
See also :py:class:`monai.transforms.SpatialPad`
"""
- LazyTransform.__init__(self, lazy)
padder = DivisiblePad(k=k, method=method, lazy=lazy, **kwargs)
- Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys)
+ Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy)
class Cropd(MapTransform, InvertibleTransform, LazyTransform):
diff --git a/tests/padders.py b/tests/padders.py
index 02d7b40af6..ae1153bdfd 100644
--- a/tests/padders.py
+++ b/tests/padders.py
@@ -136,6 +136,9 @@ def pad_test_pending_ops(self, input_param, input_shape):
# TODO: mode="bilinear" may report error
overrides = {"mode": "nearest", "padding_mode": mode[1], "align_corners": False}
result = apply_pending(pending_result, overrides=overrides)[0]
+ # lazy in constructor
+ pad_fn_lazy = self.Padder(mode=mode[0], lazy=True, **input_param)
+ self.assertTrue(pad_fn_lazy.lazy)
# compare
assert_allclose(result, expected, rtol=1e-5)
if isinstance(result, MetaTensor) and not isinstance(pad_fn, MapTransform):
From 3f3e03c702868880915937a87ac31bbfa14701c8 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Thu, 28 Dec 2023 22:22:52 +0800
Subject: [PATCH 38/88] Update openslide-python version (#7344)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
requirements-dev.txt | 2 +-
setup.cfg | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index cacbefe234..2639c0a3e7 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -27,7 +27,7 @@ ninja
torchvision
psutil
cucim>=23.2.0; platform_system == "Linux"
-openslide-python==1.1.2
+openslide-python
imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
tifffile; platform_system == "Linux" or platform_system == "Darwin"
pandas
diff --git a/setup.cfg b/setup.cfg
index 0370d0062d..b38974a1a4 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -60,7 +60,7 @@ all =
lmdb
psutil
cucim>=23.2.0
- openslide-python==1.1.2
+ openslide-python
tifffile
imagecodecs
pandas
@@ -113,7 +113,7 @@ psutil =
cucim =
cucim>=23.2.0
openslide =
- openslide-python==1.1.2
+ openslide-python
tifffile =
tifffile
imagecodecs =
From 71d838fc6760676bb1cfdad9f9b0b5db209a3607 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Fri, 29 Dec 2023 12:33:46 +0800
Subject: [PATCH 39/88] Upgrade the version of `transformers` (#7343)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7338
### Description
transformers' version is pinned to v4.22 since
https://github.com/Project-MONAI/MONAI/issues/5157.
Updated the version refer to
https://github.com/huggingface/transformers/issues/21678.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/networks/nets/transchex.py | 49 +++++++++-----------------------
requirements-dev.txt | 2 +-
tests/test_transchex.py | 3 +-
3 files changed, 15 insertions(+), 39 deletions(-)
diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py
index ff27903cef..6bfff3c956 100644
--- a/monai/networks/nets/transchex.py
+++ b/monai/networks/nets/transchex.py
@@ -12,20 +12,17 @@
from __future__ import annotations
import math
-import os
-import shutil
-import tarfile
-import tempfile
from collections.abc import Sequence
import torch
from torch import nn
+from monai.config.type_definitions import PathLike
from monai.utils import optional_import
transformers = optional_import("transformers")
load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert")[0]
-cached_path = optional_import("transformers.file_utils", name="cached_path")[0]
+cached_file = optional_import("transformers.utils", name="cached_file")[0]
BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0]
BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0]
@@ -63,44 +60,16 @@ def from_pretrained(
state_dict=None,
cache_dir=None,
from_tf=False,
+ path_or_repo_id="bert-base-uncased",
+ filename="pytorch_model.bin",
*inputs,
**kwargs,
):
- archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz"
- resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
- tempdir = None
- if os.path.isdir(resolved_archive_file) or from_tf:
- serialization_dir = resolved_archive_file
- else:
- tempdir = tempfile.mkdtemp()
- with tarfile.open(resolved_archive_file, "r:gz") as archive:
-
- def is_within_directory(directory, target):
- abs_directory = os.path.abspath(directory)
- abs_target = os.path.abspath(target)
-
- prefix = os.path.commonprefix([abs_directory, abs_target])
-
- return prefix == abs_directory
-
- def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
- for member in tar.getmembers():
- member_path = os.path.join(path, member.name)
- if not is_within_directory(path, member_path):
- raise Exception("Attempted Path Traversal in Tar File")
-
- tar.extractall(path, members, numeric_owner=numeric_owner)
-
- safe_extract(archive, tempdir)
- serialization_dir = tempdir
+ weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir)
model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
if state_dict is None and not from_tf:
- weights_path = os.path.join(serialization_dir, "pytorch_model.bin")
state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None)
- if tempdir:
- shutil.rmtree(tempdir)
if from_tf:
- weights_path = os.path.join(serialization_dir, "model.ckpt")
return load_tf_weights_in_bert(model, weights_path)
old_keys = []
new_keys = []
@@ -304,6 +273,8 @@ def __init__(
chunk_size_feed_forward: int = 0,
is_decoder: bool = False,
add_cross_attention: bool = False,
+ path_or_repo_id: str | PathLike = "bert-base-uncased",
+ filename: str = "pytorch_model.bin",
) -> None:
"""
Args:
@@ -315,6 +286,10 @@ def __init__(
num_vision_layers: number of vision transformer layers.
num_mixed_layers: number of mixed transformer layers.
drop_out: fraction of the input units to drop.
+ path_or_repo_id: This can be either:
+ - a string, the *model id* of a model repo on huggingface.co.
+ - a path to a *directory* potentially containing the file.
+ filename: The name of the file to locate in `path_or_repo`.
The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.
@@ -369,6 +344,8 @@ def __init__(
num_vision_layers=num_vision_layers,
num_mixed_layers=num_mixed_layers,
bert_config=bert_config,
+ path_or_repo_id=path_or_repo_id,
+ filename=filename,
)
self.patch_size = patch_size
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 2639c0a3e7..4685cd1572 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -33,7 +33,7 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin"
pandas
requests
einops
-transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157
+transformers>=4.36.0
mlflow>=1.28.0
clearml>=1.10.0rc0
matplotlib!=3.5.0
diff --git a/tests/test_transchex.py b/tests/test_transchex.py
index 8fb1f56715..9ad847cdaa 100644
--- a/tests/test_transchex.py
+++ b/tests/test_transchex.py
@@ -18,7 +18,7 @@
from monai.networks import eval_mode
from monai.networks.nets.transchex import Transchex
-from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_quick
+from tests.utils import skip_if_quick
TEST_CASE_TRANSCHEX = []
for drop_out in [0.4]:
@@ -46,7 +46,6 @@
@skip_if_quick
-@SkipIfAtLeastPyTorchVersion((1, 10))
class TestTranschex(unittest.TestCase):
@parameterized.expand(TEST_CASE_TRANSCHEX)
def test_shape(self, input_param, expected_shape):
From 80ed15f3f818db399a82bab986a4ceb8db54054d Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Tue, 2 Jan 2024 03:34:50 +0000
Subject: [PATCH 40/88] Bump github/codeql-action from 2 to 3 (#7354)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Bumps [github/codeql-action](https://github.com/github/codeql-action)
from 2 to 3.
Release notes
Sourced from github/codeql-action's
releases.
CodeQL Bundle v2.15.5
Bundles CodeQL CLI v2.15.5
Includes the following CodeQL language packs from github/codeql@codeql-cli/v2.15.5:
codeql/cpp-queries (changelog,
source)
codeql/cpp-all (changelog,
source)
codeql/csharp-queries (changelog,
source)
codeql/csharp-all (changelog,
source)
codeql/go-queries (changelog,
source)
codeql/go-all (changelog,
source)
codeql/java-queries (changelog,
source)
codeql/java-all (changelog,
source)
codeql/javascript-queries (changelog,
source)
codeql/javascript-all (changelog,
source)
codeql/python-queries (changelog,
source)
codeql/python-all (changelog,
source)
codeql/ruby-queries (changelog,
source)
codeql/ruby-all (changelog,
source)
codeql/swift-queries (changelog,
source)
codeql/swift-all (changelog,
source)
CodeQL Bundle v2.15.4
Bundles CodeQL CLI v2.15.4
Includes the following CodeQL language packs from github/codeql@codeql-cli/v2.15.4:
codeql/cpp-queries (changelog,
source)
codeql/cpp-all (changelog,
source)
codeql/csharp-queries (changelog,
source)
codeql/csharp-all (changelog,
source)
codeql/go-queries (changelog,
source)
codeql/go-all (changelog,
source)
codeql/java-queries (changelog,
source)
codeql/java-all (changelog,
source)
codeql/javascript-queries (changelog,
source)
codeql/javascript-all (changelog,
source)
codeql/python-queries (changelog,
source)
codeql/python-all (changelog,
source)
codeql/ruby-queries (changelog,
source)
codeql/ruby-all (changelog,
source)
codeql/swift-queries (changelog,
source)
codeql/swift-all (changelog,
source)
CodeQL Bundle
Bundles CodeQL CLI v2.15.3
Includes the following CodeQL language packs from github/codeql@codeql-cli/v2.15.3:
... (truncated)
Changelog
Sourced from github/codeql-action's
changelog.
Commits
e0c2b0a
change version numbers inside processing function as well
8e4a6c7
improve handling of changelog processing for backports
511f073
Merge pull request #2033
from github/dependabot/npm_and_yarn/npm-0a98872b3d
ebf5a83
Merge pull request #2035
from github/mergeback/v3.22.11-to-main-b374143c
7813bda
Update checked-in dependencies
2b2fb6b
Update changelog and version after v3.22.11
b374143
Merge pull request #2034
from github/update-v3.22.11-64e61baea
95591ba
Merge branch 'main' into dependabot/npm_and_yarn/npm-0a98872b3d
e2b5cc7
Update changelog for v3.22.11
- See full diff in compare
view
[](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/codeql-analysis.yml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml
index 3d32ae407a..18f1519b5a 100644
--- a/.github/workflows/codeql-analysis.yml
+++ b/.github/workflows/codeql-analysis.yml
@@ -42,7 +42,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
- uses: github/codeql-action/init@v2
+ uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
@@ -72,4 +72,4 @@ jobs:
BUILD_MONAI=1 ./runtests.sh --build
- name: Perform CodeQL Analysis
- uses: github/codeql-action/analyze@v2
+ uses: github/codeql-action/analyze@v3
From c210768301af44d152e053e4a5db0e4aac3b90f9 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Tue, 2 Jan 2024 06:42:41 +0000
Subject: [PATCH 41/88] Bump actions/upload-artifact from 3 to 4 (#7350)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Bumps
[actions/upload-artifact](https://github.com/actions/upload-artifact)
from 3 to 4.
Release notes
Sourced from actions/upload-artifact's
releases.
v4.0.0
What's Changed
The release of upload-artifact@v4 and download-artifact@v4 are major
changes to the backend architecture of Artifacts. They have numerous
performance and behavioral improvements.
For more information, see the @actions/artifact
documentation.
New Contributors
Full Changelog: https://github.com/actions/upload-artifact/compare/v3...v4.0.0
v3.1.3
What's Changed
Full Changelog: https://github.com/actions/upload-artifact/compare/v3...v3.1.3
v3.1.2
- Update all
@actions/* NPM packages to their latest
versions- #374
- Update all dev dependencies to their most recent versions - #375
v3.1.1
- Update actions/core package to latest version to remove
set-output deprecation warning #351
v3.1.0
What's Changed
Commits
[](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/docker.yml | 2 +-
.github/workflows/release.yml | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index f51e4fdf76..f80a4c2c96 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -36,7 +36,7 @@ jobs:
python setup.py build
cat build/lib/monai/_version.py
- name: Upload version
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: _version.py
path: build/lib/monai/_version.py
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 7197215486..e9817e1c4c 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -66,7 +66,7 @@ jobs:
- if: matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/')
name: Upload artifacts
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: dist
path: dist/
@@ -108,7 +108,7 @@ jobs:
python setup.py build
cat build/lib/monai/_version.py
- name: Upload version
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: _version.py
path: build/lib/monai/_version.py
From 285dcfce71bbd8243cc28bf4ca7173f9c22912dd Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Tue, 2 Jan 2024 17:29:09 +0800
Subject: [PATCH 42/88] Bump actions/setup-python from 4 to 5 (#7351)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Bumps [actions/setup-python](https://github.com/actions/setup-python)
from 4 to 5.
Release notes
Sourced from actions/setup-python's
releases.
v5.0.0
What's Changed
In scope of this release, we update node version runtime from node16
to node20 (actions/setup-python#772).
Besides, we update dependencies to the latest versions.
Full Changelog: https://github.com/actions/setup-python/compare/v4.8.0...v5.0.0
v4.8.0
What's Changed
In scope of this release we added support for GraalPy (actions/setup-python#694).
You can use this snippet to set up GraalPy:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 'graalpy-22.3'
- run: python my_script.py
Besides, the release contains such changes as:
New Contributors
Full Changelog: https://github.com/actions/setup-python/compare/v4...v4.8.0
v4.7.1
What's Changed
Full Changelog: https://github.com/actions/setup-python/compare/v4...v4.7.1
v4.7.0
In scope of this release, the support for reading python version from
pyproject.toml was added (actions/setup-python#669).
- name: Setup Python
uses: actions/setup-python@v4
</tr></table>
... (truncated)
Commits
[](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/cron-ngc-bundle.yml | 2 +-
.github/workflows/docker.yml | 2 +-
.github/workflows/pythonapp-min.yml | 6 +++---
.github/workflows/pythonapp.yml | 8 ++++----
.github/workflows/release.yml | 4 ++--
.github/workflows/setupapp.yml | 4 ++--
.github/workflows/weekly-preview.yml | 2 +-
7 files changed, 14 insertions(+), 14 deletions(-)
diff --git a/.github/workflows/cron-ngc-bundle.yml b/.github/workflows/cron-ngc-bundle.yml
index 0bba630d03..84666204a9 100644
--- a/.github/workflows/cron-ngc-bundle.yml
+++ b/.github/workflows/cron-ngc-bundle.yml
@@ -19,7 +19,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.8
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.8'
- name: cache weekly timestamp
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index f80a4c2c96..c375e82e74 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -26,7 +26,7 @@ jobs:
ref: dev
fetch-depth: 0
- name: Set up Python 3.9
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.9'
- shell: bash
diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml
index 558c270e33..7b7930bdf5 100644
--- a/.github/workflows/pythonapp-min.yml
+++ b/.github/workflows/pythonapp-min.yml
@@ -30,7 +30,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.8
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.8'
- name: Prepare pip wheel
@@ -76,7 +76,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Prepare pip wheel
@@ -121,7 +121,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.8
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.8'
- name: Prepare pip wheel
diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
index ad8b555dd4..29a79759e0 100644
--- a/.github/workflows/pythonapp.yml
+++ b/.github/workflows/pythonapp.yml
@@ -28,7 +28,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.8
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.8'
- name: cache weekly timestamp
@@ -69,7 +69,7 @@ jobs:
disk-root: "D:"
- uses: actions/checkout@v4
- name: Set up Python 3.8
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.8'
- name: Prepare pip wheel
@@ -128,7 +128,7 @@ jobs:
with:
fetch-depth: 0
- name: Set up Python 3.8
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.8'
- name: cache weekly timestamp
@@ -209,7 +209,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.8
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.8'
- name: cache weekly timestamp
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index e9817e1c4c..9334908bfc 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -19,7 +19,7 @@ jobs:
with:
fetch-depth: 0
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install setuptools
@@ -97,7 +97,7 @@ jobs:
with:
fetch-depth: 0
- name: Set up Python 3.9
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.9'
- shell: bash
diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml
index 0ff7162bee..82394a86dd 100644
--- a/.github/workflows/setupapp.yml
+++ b/.github/workflows/setupapp.yml
@@ -83,7 +83,7 @@ jobs:
with:
fetch-depth: 0
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: cache weekly timestamp
@@ -120,7 +120,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Set up Python 3.8
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.8'
- name: cache weekly timestamp
diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml
index c631982745..e94e1dac5a 100644
--- a/.github/workflows/weekly-preview.yml
+++ b/.github/workflows/weekly-preview.yml
@@ -14,7 +14,7 @@ jobs:
ref: dev
fetch-depth: 0
- name: Set up Python 3.9
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.9'
- name: Install setuptools
From a6c83d0ebf67550becbb315d75efea28a7c9adc1 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Tue, 2 Jan 2024 23:42:25 +0800
Subject: [PATCH 43/88] Bump actions/download-artifact from 3 to 4 (#7352)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Bumps
[actions/download-artifact](https://github.com/actions/download-artifact)
from 3 to 4.
Release notes
Sourced from actions/download-artifact's
releases.
v4.0.0
What's Changed
The release of upload-artifact@v4 and download-artifact@v4 are major
changes to the backend architecture of Artifacts. They have numerous
performance and behavioral improvements.
For more information, see the @actions/artifact
documentation.
New Contributors
Full Changelog: https://github.com/actions/download-artifact/compare/v3...v4.0.0
v3.0.2
- Bump
@actions/artifact to v1.1.1 - actions/download-artifact#195
- Fixed a bug in Node16 where if an HTTP download finished too quickly
(<1ms, e.g. when it's mocked) we attempt to delete a temp file that
has not been created yet actions/toolkit#1278
v3.0.1
Commits
f44cd7b
Merge pull request #259
from actions/robherley/glob-downloads
3181fe8
add some migration docs
aaaac7b
licensed cache
7c9182f
update readme
b94e701
licensed cache
0b55470
add test case for globbed downloads to same directory
0b51c2e
update prettier/eslint versions
c4c6db7
support globbing artifact list & merging download directory
1bd0606
Merge pull request #252
from stchr/patch-1
eff4d42
fix default for run-id
- Additional commits viewable in compare
view
[](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/docker.yml | 2 +-
.github/workflows/release.yml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index c375e82e74..229ae675f5 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -56,7 +56,7 @@ jobs:
with:
ref: dev
- name: Download version
- uses: actions/download-artifact@v3
+ uses: actions/download-artifact@v4
with:
name: _version.py
- name: docker_build
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 9334908bfc..a03d2cea6c 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -125,7 +125,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Download version
- uses: actions/download-artifact@v3
+ uses: actions/download-artifact@v4
with:
name: _version.py
- name: Set tag
From ac56b50a45620617dfbd44140fb70a74120bf47c Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Wed, 3 Jan 2024 10:41:55 +0800
Subject: [PATCH 44/88] Bump peter-evans/slash-command-dispatch from 3.0.1 to
3.0.2 (#7353)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Bumps
[peter-evans/slash-command-dispatch](https://github.com/peter-evans/slash-command-dispatch)
from 3.0.1 to 3.0.2.
Release notes
Sourced from peter-evans/slash-command-dispatch's
releases.
Slash Command Dispatch v3.0.2
What's Changed
New Contributors
Full Changelog: https://github.com/peter-evans/slash-command-dispatch/compare/v3.0.1...v3.0.2
Commits
f996d7b
Fix the CollaboratorPermission GraphQL query (#301)
05b97d6
build(deps-dev): bump @types/node from 16.18.65 to
16.18.67 (#300)
8e70073
build(deps-dev): bump eslint from 8.54.0 to 8.55.0 (#299)
bd00135
build(deps-dev): bump @types/node from 16.18.62 to
16.18.65 (#298)
ee873b6
build(deps-dev): bump eslint from 8.53.0 to 8.54.0 (#296)
44abc47
build(deps-dev): bump @types/node from 16.18.61 to
16.18.62 (#295)
19ad7b8
build(deps-dev): bump @types/node from 16.18.60 to
16.18.61 (#294)
29a9815
build(deps-dev): bump prettier from 3.0.3 to 3.1.0 (#293)
ade0309
build(deps-dev): bump eslint from 8.52.0 to 8.53.0 (#292)
fc8222e
build(deps-dev): bump @types/node from 16.18.59 to
16.18.60 (#291)
- Additional commits viewable in compare
view
[](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/chatops.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/chatops.yml b/.github/workflows/chatops.yml
index b4e201a0d9..59c7d070b4 100644
--- a/.github/workflows/chatops.yml
+++ b/.github/workflows/chatops.yml
@@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: dispatch
- uses: peter-evans/slash-command-dispatch@v3.0.1
+ uses: peter-evans/slash-command-dispatch@v3.0.2
with:
token: ${{ secrets.PR_MAINTAIN }}
reaction-token: ${{ secrets.GITHUB_TOKEN }}
From 2fa5bf1c935eca7e385084d3152bfdef4ffacf0f Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Mon, 8 Jan 2024 11:18:45 +0800
Subject: [PATCH 45/88] Give more useful exception when batch is considered
during matrix multiplication (#7326)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7323
### Description
Give more useful exception when batch is considered during matrix
multiplication.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/inverse.py | 12 +++++++++++-
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py
index 41fabb35aa..f94f11eca9 100644
--- a/monai/transforms/inverse.py
+++ b/monai/transforms/inverse.py
@@ -185,7 +185,17 @@ def track_transform_meta(
# not lazy evaluation, directly update the metatensor affine (don't push to the stack)
orig_affine = data_t.peek_pending_affine()
orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0]
- affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64)
+ try:
+ affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64)
+ except RuntimeError as e:
+ if orig_affine.ndim > 2:
+ if data_t.is_batch:
+ msg = "Transform applied to batched tensor, should be applied to instances only"
+ else:
+ msg = "Mismatch affine matrix, ensured that the batch dimension is not included in the calculation."
+ raise RuntimeError(msg) from e
+ else:
+ raise
out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"), dtype=torch.float64)
if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):
From 23ab35c4ec0a7c0f926257f835c4a651a41628bf Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 9 Jan 2024 11:13:14 +0800
Subject: [PATCH 46/88] Fix incorrectly size compute in auto3dseg analyzer
(#7374)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7222
### Description
remove int convert here.
https://github.com/Project-MONAI/MONAI/blob/8fa6931b14ba9617a595fff1d396ac44cc82e207/monai/auto3dseg/analyzer.py#L259
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/auto3dseg/analyzer.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py
index d5cfb21dab..56419da4cb 100644
--- a/monai/auto3dseg/analyzer.py
+++ b/monai/auto3dseg/analyzer.py
@@ -256,7 +256,7 @@ def __call__(self, data):
)
report[ImageStatsKeys.SIZEMM] = [
- int(a * b) for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
+ a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
]
report[ImageStatsKeys.INTENSITY] = [
From 7cc7c9bfedf17553f3a4118f742b930315d31d85 Mon Sep 17 00:00:00 2001
From: Kaibo Tang
Date: Tue, 9 Jan 2024 22:41:30 -0500
Subject: [PATCH 47/88] 7380 mention demo in bending energy and diffusion
docstrings (#7381)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7380.
### Description
Mention
[demo](https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb)
in bending energy and diffusion docstrings.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: kaibo
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/losses/deform.py | 16 +++++++++++++---
1 file changed, 13 insertions(+), 3 deletions(-)
diff --git a/monai/losses/deform.py b/monai/losses/deform.py
index 129abeedd2..37e4468d4b 100644
--- a/monai/losses/deform.py
+++ b/monai/losses/deform.py
@@ -46,7 +46,10 @@ def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor:
class BendingEnergyLoss(_Loss):
"""
- Calculate the bending energy based on second-order differentiation of pred using central finite difference.
+ Calculate the bending energy based on second-order differentiation of ``pred`` using central finite difference.
+
+ For more information,
+ see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb.
Adapted from:
DeepReg (https://github.com/DeepRegNet/DeepReg)
@@ -75,6 +78,9 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor:
Raises:
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
+ ValueError: When ``pred`` is not 3-d, 4-d or 5-d.
+ ValueError: When any spatial dimension of ``pred`` has size less than or equal to 4.
+ ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions.
"""
if pred.ndim not in [3, 4, 5]:
@@ -84,7 +90,8 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor:
raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}")
if pred.shape[1] != pred.ndim - 2:
raise ValueError(
- f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim-2}"
+ f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, "
+ f"does not match number of spatial dimensions, {pred.ndim - 2}"
)
# first order gradient
@@ -120,12 +127,15 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor:
class DiffusionLoss(_Loss):
"""
- Calculate the diffusion based on first-order differentiation of pred using central finite difference.
+ Calculate the diffusion based on first-order differentiation of ``pred`` using central finite difference.
For the original paper, please refer to
VoxelMorph: A Learning Framework for Deformable Medical Image Registration,
Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca
IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231.
+ For more information,
+ see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb.
+
Adapted from:
VoxelMorph (https://github.com/voxelmorph/voxelmorph)
"""
From 2e382b46a5dfcd8b70ded012cb7e2a35a33fce9a Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Fri, 12 Jan 2024 20:25:42 +0800
Subject: [PATCH 48/88] Pin gdown version to v4.6.3 (#7384)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Workaround for #7382 #7383
### Description
Based on the comment
[here](https://github.com/wkentaro/gdown/issues/291#issuecomment-1887060708),
pin the gdown version as a workaround. Will review this one once gdown
has some update internal.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/apps/utils.py | 2 +-
requirements-dev.txt | 2 +-
setup.cfg | 4 ++--
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/monai/apps/utils.py b/monai/apps/utils.py
index d2dd63b958..442dbabba0 100644
--- a/monai/apps/utils.py
+++ b/monai/apps/utils.py
@@ -30,7 +30,7 @@
from monai.config.type_definitions import PathLike
from monai.utils import look_up_option, min_version, optional_import
-gdown, has_gdown = optional_import("gdown", "4.4")
+gdown, has_gdown = optional_import("gdown", "4.6.3")
if TYPE_CHECKING:
from tqdm import tqdm
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 4685cd1572..f8bc9d5a3e 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,7 +1,7 @@
# Full requirements for developments
-r requirements-min.txt
pytorch-ignite==0.4.11
-gdown>=4.4.0
+gdown>=4.4.0, <=4.6.3
scipy>=1.7.1
itk>=5.2
nibabel
diff --git a/setup.cfg b/setup.cfg
index b38974a1a4..1141dc0ef8 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -52,7 +52,7 @@ all =
scipy>=1.7.1
pillow
tensorboard
- gdown>=4.4.0
+ gdown==4.6.3
pytorch-ignite==0.4.11
torchvision
itk>=5.2
@@ -97,7 +97,7 @@ pillow =
tensorboard =
tensorboard
gdown =
- gdown>=4.4.0
+ gdown==4.6.3
ignite =
pytorch-ignite==0.4.11
torchvision =
From 7e7d278f6c42b6ddd560fc4c3f4aca8783c167c0 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Thu, 18 Jan 2024 18:01:02 +0800
Subject: [PATCH 49/88] Fix Premerge (#7397)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7396
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
setup.cfg | 2 ++
tests/test_flexible_unet.py | 2 +-
tests/test_invertd.py | 12 ++++++------
3 files changed, 9 insertions(+), 7 deletions(-)
diff --git a/setup.cfg b/setup.cfg
index 1141dc0ef8..0069214de3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -173,6 +173,7 @@ max_line_length = 120
# B028 https://github.com/Project-MONAI/MONAI/issues/5855
# B907 https://github.com/Project-MONAI/MONAI/issues/5868
# B908 https://github.com/Project-MONAI/MONAI/issues/6503
+# B036 https://github.com/Project-MONAI/MONAI/issues/7396
ignore =
E203
E501
@@ -186,6 +187,7 @@ ignore =
B028
B907
B908
+ B036
per_file_ignores = __init__.py: F401, __main__.py: F401
exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py
diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py
index 1218ce6e85..1d831f0976 100644
--- a/tests/test_flexible_unet.py
+++ b/tests/test_flexible_unet.py
@@ -39,7 +39,7 @@ class DummyEncoder(BaseEncoder):
def get_encoder_parameters(cls):
basic_dict = {"spatial_dims": 2, "in_channels": 3, "pretrained": False}
param_dict_list = [basic_dict]
- for key in basic_dict:
+ for key in basic_dict.keys():
cur_dict = basic_dict.copy()
del cur_dict[key]
param_dict_list.append(cur_dict)
diff --git a/tests/test_invertd.py b/tests/test_invertd.py
index cd2e91257a..2e6ee35981 100644
--- a/tests/test_invertd.py
+++ b/tests/test_invertd.py
@@ -112,15 +112,15 @@ def test_invert(self):
self.assertTupleEqual(i.shape[1:], (101, 100, 107))
# check the case that different items use different interpolation mode to invert transforms
- d = item["image_inverted1"]
+ j = item["image_inverted1"]
# if the interpolation mode is nearest, accumulated diff should be smaller than 1
- self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
- self.assertTupleEqual(d.shape, (1, 101, 100, 107))
+ self.assertLess(torch.sum(j.to(torch.float) - j.to(torch.uint8).to(torch.float)).item(), 1.0)
+ self.assertTupleEqual(j.shape, (1, 101, 100, 107))
- d = item["label_inverted1"]
+ k = item["label_inverted1"]
# if the interpolation mode is not nearest, accumulated diff should be greater than 10000
- self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
- self.assertTupleEqual(d.shape, (1, 101, 100, 107))
+ self.assertGreater(torch.sum(k.to(torch.float) - k.to(torch.uint8).to(torch.float)).item(), 10000.0)
+ self.assertTupleEqual(k.shape, (1, 101, 100, 107))
# check labels match
reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32)
From ad7e3fa34f6754e6759c16a7a7d1d400d07d41fb Mon Sep 17 00:00:00 2001
From: "axel.vlaminck"
Date: Thu, 18 Jan 2024 18:07:22 +0100
Subject: [PATCH 50/88] Track applied operations in image filter (#7395)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7394
### Description
When ImageFilter is in the transformation sequence it didn't pass the
applied_operations.
Now it is passed when present.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: axel.vlaminck
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/utility/array.py | 11 ++++++++---
tests/test_image_filter.py | 16 ++++++++++++++++
2 files changed, 24 insertions(+), 3 deletions(-)
diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index 2322f2123f..5dfbcb0e91 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -1562,17 +1562,22 @@ def __init__(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int |
self.filter_size = filter_size
self.additional_args_for_filter = kwargs
- def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> NdarrayOrTensor:
+ def __call__(
+ self, img: NdarrayOrTensor, meta_dict: dict | None = None, applied_operations: list | None = None
+ ) -> NdarrayOrTensor:
"""
Args:
img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]]
meta_dict: An optional dictionary with metadata
+ applied_operations: An optional list of operations that have been applied to the data
Returns:
A MetaTensor with the same shape as `img` and identical metadata
"""
if isinstance(img, MetaTensor):
meta_dict = img.meta
+ applied_operations = img.applied_operations
+
img_, prev_type, device = convert_data_type(img, torch.Tensor)
ndim = img_.ndim - 1 # assumes channel first format
@@ -1582,8 +1587,8 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> Ndarr
self.filter = ApplyFilter(self.filter)
img_ = self._apply_filter(img_)
- if meta_dict:
- img_ = MetaTensor(img_, meta=meta_dict)
+ if meta_dict is not None or applied_operations is not None:
+ img_ = MetaTensor(img_, meta=meta_dict, applied_operations=applied_operations)
else:
img_, *_ = convert_data_type(img_, prev_type, device)
return img_
diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py
index 841a5d5cd5..985ea95e79 100644
--- a/tests/test_image_filter.py
+++ b/tests/test_image_filter.py
@@ -17,6 +17,7 @@
import torch
from parameterized import parameterized
+from monai.data.meta_tensor import MetaTensor
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd
@@ -115,6 +116,21 @@ def test_call_3d(self, filter_name):
out_tensor = filter(SAMPLE_IMAGE_3D)
self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:])
+ def test_pass_applied_operations(self):
+ "Test that applied operations are passed through"
+ applied_operations = ["op1", "op2"]
+ image = MetaTensor(SAMPLE_IMAGE_2D, applied_operations=applied_operations)
+ filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS)
+ out_tensor = filter(image)
+ self.assertEqual(out_tensor.applied_operations, applied_operations)
+
+ def test_pass_empty_metadata_dict(self):
+ "Test that applied operations are passed through"
+ image = MetaTensor(SAMPLE_IMAGE_2D, meta={})
+ filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS)
+ out_tensor = filter(image)
+ self.assertTrue(isinstance(out_tensor, MetaTensor))
+
class TestImageFilterDict(unittest.TestCase):
@parameterized.expand(SUPPORTED_FILTERS)
From aef4b577ae234630a3f6a2057630dd4fc2223c4a Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Fri, 19 Jan 2024 16:00:13 +0800
Subject: [PATCH 51/88] Add `compile` support in `SupervisedTrainer` and
`SupervisedEvaluator` (#7375)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes # .
### Description
Add `compile` support in `SupervisedTrainer` and `SupervisedEvaluator`.
Convert to `torch.Tensor` internally.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/engines/evaluator.py | 51 +++++++++++++++++++++++++++++++++++--
monai/engines/trainer.py | 52 ++++++++++++++++++++++++++++++++++++--
2 files changed, 99 insertions(+), 4 deletions(-)
diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py
index 119853d5c5..2c8dfe6b85 100644
--- a/monai/engines/evaluator.py
+++ b/monai/engines/evaluator.py
@@ -11,12 +11,14 @@
from __future__ import annotations
+import warnings
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
import torch
from torch.utils.data import DataLoader
from monai.config import IgniteInfo, KeysCollection
+from monai.data import MetaTensor
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
@@ -25,7 +27,7 @@
from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.enums import EngineStatsKeys as ESKeys
-from monai.utils.module import look_up_option
+from monai.utils.module import look_up_option, pytorch_after
if TYPE_CHECKING:
from ignite.engine import Engine, EventEnum
@@ -213,6 +215,10 @@ class SupervisedEvaluator(Evaluator):
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+ compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
+ `torch.Tensor` before forward pass, then converted back afterward with copied meta information.
+ compile_kwargs: dict of the args for `torch.compile()` API, for more details:
+ https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
"""
@@ -238,6 +244,8 @@ def __init__(
decollate: bool = True,
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
+ compile: bool = False,
+ compile_kwargs: dict | None = None,
) -> None:
super().__init__(
device=device,
@@ -259,8 +267,16 @@ def __init__(
to_kwargs=to_kwargs,
amp_kwargs=amp_kwargs,
)
-
+ if compile:
+ if pytorch_after(2, 1):
+ compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+ network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
+ else:
+ warnings.warn(
+ "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
+ )
self.network = network
+ self.compile = compile
self.inferer = SimpleInferer() if inferer is None else inferer
def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict:
@@ -288,6 +304,24 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
kwargs: dict = {}
else:
inputs, targets, args, kwargs = batch
+ # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026
+ if self.compile:
+ inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None
+ if isinstance(inputs, MetaTensor):
+ warnings.warn(
+ "Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass."
+ )
+ inputs, inputs_meta, inputs_applied_operations = (
+ inputs.as_tensor(),
+ inputs.meta,
+ inputs.applied_operations,
+ )
+ if isinstance(targets, MetaTensor):
+ targets, targets_meta, targets_applied_operations = (
+ targets.as_tensor(),
+ targets.meta,
+ targets.applied_operations,
+ )
# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
@@ -298,6 +332,19 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
else:
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
+ # copy back meta info
+ if self.compile:
+ if inputs_meta is not None:
+ engine.state.output[Keys.IMAGE] = MetaTensor(
+ inputs, meta=inputs_meta, applied_operations=inputs_applied_operations
+ )
+ engine.state.output[Keys.PRED] = MetaTensor(
+ engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations
+ )
+ if targets_meta is not None:
+ engine.state.output[Keys.LABEL] = MetaTensor(
+ targets, meta=targets_meta, applied_operations=targets_applied_operations
+ )
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
engine.fire_event(IterationEvents.MODEL_COMPLETED)
diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py
index 61b7028e11..f1513ea73b 100644
--- a/monai/engines/trainer.py
+++ b/monai/engines/trainer.py
@@ -11,6 +11,7 @@
from __future__ import annotations
+import warnings
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
import torch
@@ -18,6 +19,7 @@
from torch.utils.data import DataLoader
from monai.config import IgniteInfo
+from monai.data import MetaTensor
from monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
@@ -25,6 +27,7 @@
from monai.utils import GanKeys, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.enums import EngineStatsKeys as ESKeys
+from monai.utils.module import pytorch_after
if TYPE_CHECKING:
from ignite.engine import Engine, EventEnum
@@ -125,7 +128,10 @@ class SupervisedTrainer(Trainer):
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
-
+ compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
+ `torch.Tensor` before forward pass, then converted back afterward with copied meta information.
+ compile_kwargs: dict of the args for `torch.compile()` API, for more details:
+ https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
"""
def __init__(
@@ -153,6 +159,8 @@ def __init__(
optim_set_to_none: bool = False,
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
+ compile: bool = False,
+ compile_kwargs: dict | None = None,
) -> None:
super().__init__(
device=device,
@@ -174,8 +182,16 @@ def __init__(
to_kwargs=to_kwargs,
amp_kwargs=amp_kwargs,
)
-
+ if compile:
+ if pytorch_after(2, 1):
+ compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+ network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
+ else:
+ warnings.warn(
+ "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
+ )
self.network = network
+ self.compile = compile
self.optimizer = optimizer
self.loss_function = loss_function
self.inferer = SimpleInferer() if inferer is None else inferer
@@ -207,6 +223,25 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso
kwargs: dict = {}
else:
inputs, targets, args, kwargs = batch
+ # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026
+ if self.compile:
+ inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None
+ if isinstance(inputs, MetaTensor):
+ warnings.warn(
+ "Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass."
+ )
+ inputs, inputs_meta, inputs_applied_operations = (
+ inputs.as_tensor(),
+ inputs.meta,
+ inputs.applied_operations,
+ )
+ if isinstance(targets, MetaTensor):
+ targets, targets_meta, targets_applied_operations = (
+ targets.as_tensor(),
+ targets.meta,
+ targets.applied_operations,
+ )
+
# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
@@ -231,6 +266,19 @@ def _compute_pred_loss():
engine.state.output[Keys.LOSS].backward()
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
engine.optimizer.step()
+ # copy back meta info
+ if self.compile:
+ if inputs_meta is not None:
+ engine.state.output[Keys.IMAGE] = MetaTensor(
+ inputs, meta=inputs_meta, applied_operations=inputs_applied_operations
+ )
+ engine.state.output[Keys.PRED] = MetaTensor(
+ engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations
+ )
+ if targets_meta is not None:
+ engine.state.output[Keys.LABEL] = MetaTensor(
+ targets, meta=targets_meta, applied_operations=targets_applied_operations
+ )
engine.fire_event(IterationEvents.MODEL_COMPLETED)
return engine.state.output
From d8236eab639140da550f3e1f2e5b9e1b5b494f8a Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Mon, 22 Jan 2024 23:58:29 +0800
Subject: [PATCH 52/88] Fix CUDA_VISIBLE_DEVICES setting ignored (#7408)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7407
### Description
Move `optional import cucim` inside the function to avoid using all
GPUs.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/metrics/utils.py | 6 ++----
tests/test_set_visible_devices.py | 7 +++++++
2 files changed, 9 insertions(+), 4 deletions(-)
diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py
index d4b8f6e9b6..e7057256fb 100644
--- a/monai/metrics/utils.py
+++ b/monai/metrics/utils.py
@@ -38,10 +38,6 @@
binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion")
distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt")
distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
-cucim_binary_erosion, has_cucim_binary_erosion = optional_import("cucim.skimage.morphology", name="binary_erosion")
-cucim_distance_transform_edt, has_cucim_distance_transform_edt = optional_import(
- "cucim.core.operations.morphology", name="distance_transform_edt"
-)
__all__ = [
"ignore_background",
@@ -179,6 +175,8 @@ def get_mask_edges(
always_return_as_numpy: whether to a numpy array regardless of the input type.
If False, return the same type as inputs.
"""
+ # move in the funciton to avoid using all the GPUs
+ cucim_binary_erosion, has_cucim_binary_erosion = optional_import("cucim.skimage.morphology", name="binary_erosion")
if seg_pred.shape != seg_gt.shape:
raise ValueError(f"seg_pred and seg_gt should have same shapes, got {seg_pred.shape} and {seg_gt.shape}.")
converter: Any
diff --git a/tests/test_set_visible_devices.py b/tests/test_set_visible_devices.py
index 53703e107a..993e8a4ac2 100644
--- a/tests/test_set_visible_devices.py
+++ b/tests/test_set_visible_devices.py
@@ -35,6 +35,13 @@ def test_visible_devices(self):
)
self.assertEqual(num_gpus_before, num_gpus_after)
+ # test import monai won't affect setting CUDA_VISIBLE_DEVICES
+ num_gpus_after_monai = self.run_process_and_get_exit_code(
+ 'python -c "import os; import torch; import monai; '
+ + "os.environ['CUDA_VISIBLE_DEVICES'] = '0'; exit(torch.cuda.device_count())\""
+ )
+ self.assertEqual(num_gpus_after_monai, 1)
+
if __name__ == "__main__":
unittest.main()
From 433a3aa6154bfaa5fec9ba600af4655de8a1d05f Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Thu, 25 Jan 2024 09:47:36 +0800
Subject: [PATCH 53/88] Fix Incorrect updated affine in `NrrdReader` and update
docstring in `ITKReader` (#7415)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7414
Fixes #7371
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/data/image_reader.py | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py
index 0823d11834..2361bb63a7 100644
--- a/monai/data/image_reader.py
+++ b/monai/data/image_reader.py
@@ -168,8 +168,8 @@ class ITKReader(ImageReader):
series_name: the name of the DICOM series if there are multiple ones.
used when loading DICOM series.
reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array.
- If ``False``, the spatial indexing follows the numpy convention;
- otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is ``False``.
+ If ``False``, the spatial indexing convention is reversed to be compatible with ITK;
+ otherwise, the spatial indexing follows the numpy convention. Default is ``False``.
This option does not affect the metadata.
series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice).
This flag is checked only when loading DICOM series. Default is ``False``.
@@ -1323,7 +1323,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]:
header = dict(i.header)
if self.index_order == "C":
header = self._convert_f_to_c_order(header)
- header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
+ header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header)
if self.affine_lps_to_ras:
header = self._switch_lps_ras(header)
@@ -1344,7 +1344,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]:
return _stack_images(img_array, compatible_meta), compatible_meta
- def _get_affine(self, img: NrrdImage) -> np.ndarray:
+ def _get_affine(self, header: dict) -> np.ndarray:
"""
Get the affine matrix of the image, it can be used to correct
spacing, orientation or execute spatial transforms.
@@ -1353,8 +1353,8 @@ def _get_affine(self, img: NrrdImage) -> np.ndarray:
img: A `NrrdImage` loaded from image file
"""
- direction = img.header["space directions"]
- origin = img.header["space origin"]
+ direction = header["space directions"]
+ origin = header["space origin"]
x, y = direction.shape
affine_diam = min(x, y) + 1
From 1b4091ce8ea84025493998688ed51a5995bc5f53 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 30 Jan 2024 09:42:03 +0800
Subject: [PATCH 54/88] Ignore E704 after update black (#7422)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7421
### Description
https://pypi.org/project/black/24.1.1/
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/apps/detection/utils/anchor_utils.py | 8 ++++++--
monai/data/decathlon_datalist.py | 6 ++----
monai/losses/image_dissimilarity.py | 4 +---
monai/transforms/utility/dictionary.py | 6 +++---
monai/utils/dist.py | 9 +++------
monai/utils/misc.py | 6 ++----
setup.cfg | 2 ++
tests/test_hilbert_transform.py | 20 +++++++++++---------
tests/test_spacing.py | 8 +++++---
9 files changed, 35 insertions(+), 34 deletions(-)
diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py
index baaa7ce874..283169b653 100644
--- a/monai/apps/detection/utils/anchor_utils.py
+++ b/monai/apps/detection/utils/anchor_utils.py
@@ -369,8 +369,12 @@ class AnchorGeneratorWithAnchorShape(AnchorGenerator):
def __init__(
self,
feature_map_scales: Sequence[int] | Sequence[float] = (1, 2, 4, 8),
- base_anchor_shapes: Sequence[Sequence[int]]
- | Sequence[Sequence[float]] = ((32, 32, 32), (48, 20, 20), (20, 48, 20), (20, 20, 48)),
+ base_anchor_shapes: Sequence[Sequence[int]] | Sequence[Sequence[float]] = (
+ (32, 32, 32),
+ (48, 20, 20),
+ (20, 48, 20),
+ (20, 20, 48),
+ ),
indexing: str = "ij",
) -> None:
nn.Module.__init__(self)
diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py
index 6f163f972e..14765dcfaa 100644
--- a/monai/data/decathlon_datalist.py
+++ b/monai/data/decathlon_datalist.py
@@ -24,13 +24,11 @@
@overload
-def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str:
- ...
+def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ...
@overload
-def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]:
- ...
+def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: ...
def _compute_path(base_dir, element, check_path=False):
diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py
index 39219e059a..dd132770ec 100644
--- a/monai/losses/image_dissimilarity.py
+++ b/monai/losses/image_dissimilarity.py
@@ -277,9 +277,7 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> tuple[torc
if order == 0:
weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5
elif order == 3:
- weight = (
- weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6
- )
+ weight = weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6
weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6
else:
raise ValueError(f"Do not support b-spline {order}-order parzen windowing")
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index ec10bd8537..1cd9ff6323 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -1765,9 +1765,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
LabelToMaskD = LabelToMaskDict = LabelToMaskd
FgBgToIndicesD = FgBgToIndicesDict = FgBgToIndicesd
ClassesToIndicesD = ClassesToIndicesDict = ClassesToIndicesd
-ConvertToMultiChannelBasedOnBratsClassesD = (
- ConvertToMultiChannelBasedOnBratsClassesDict
-) = ConvertToMultiChannelBasedOnBratsClassesd
+ConvertToMultiChannelBasedOnBratsClassesD = ConvertToMultiChannelBasedOnBratsClassesDict = (
+ ConvertToMultiChannelBasedOnBratsClassesd
+)
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
TorchVisionD = TorchVisionDict = TorchVisiond
RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond
diff --git a/monai/utils/dist.py b/monai/utils/dist.py
index 20f09628ac..2418b43591 100644
--- a/monai/utils/dist.py
+++ b/monai/utils/dist.py
@@ -50,18 +50,15 @@ def get_dist_device():
@overload
-def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor:
- ...
+def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: ...
@overload
-def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]:
- ...
+def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: ...
@overload
-def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]:
- ...
+def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: ...
def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True) -> torch.Tensor | list[torch.Tensor]:
diff --git a/monai/utils/misc.py b/monai/utils/misc.py
index d6ff370f69..2a5c5da136 100644
--- a/monai/utils/misc.py
+++ b/monai/utils/misc.py
@@ -103,13 +103,11 @@ def star_zip_with(op, *vals):
@overload
-def first(iterable: Iterable[T], default: T) -> T:
- ...
+def first(iterable: Iterable[T], default: T) -> T: ...
@overload
-def first(iterable: Iterable[T]) -> T | None:
- ...
+def first(iterable: Iterable[T]) -> T | None: ...
def first(iterable: Iterable[T], default: T | None = None) -> T | None:
diff --git a/setup.cfg b/setup.cfg
index 0069214de3..4180ced917 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -174,6 +174,7 @@ max_line_length = 120
# B907 https://github.com/Project-MONAI/MONAI/issues/5868
# B908 https://github.com/Project-MONAI/MONAI/issues/6503
# B036 https://github.com/Project-MONAI/MONAI/issues/7396
+# E704 https://github.com/Project-MONAI/MONAI/issues/7421
ignore =
E203
E501
@@ -188,6 +189,7 @@ ignore =
B907
B908
B036
+ E704
per_file_ignores = __init__.py: F401, __main__.py: F401
exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py
diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py
index 4c49aecd8b..68fa0b1192 100644
--- a/tests/test_hilbert_transform.py
+++ b/tests/test_hilbert_transform.py
@@ -180,15 +180,17 @@ def test_value(self, arguments, image, expected_data, atol):
@SkipIfNoModule("torch.fft")
class TestHilbertTransformGPU(unittest.TestCase):
@parameterized.expand(
- []
- if not torch.cuda.is_available()
- else [
- TEST_CASE_1D_SINE_GPU,
- TEST_CASE_2D_SINE_GPU,
- TEST_CASE_3D_SINE_GPU,
- TEST_CASE_1D_2CH_SINE_GPU,
- TEST_CASE_2D_2CH_SINE_GPU,
- ],
+ (
+ []
+ if not torch.cuda.is_available()
+ else [
+ TEST_CASE_1D_SINE_GPU,
+ TEST_CASE_2D_SINE_GPU,
+ TEST_CASE_3D_SINE_GPU,
+ TEST_CASE_1D_2CH_SINE_GPU,
+ TEST_CASE_2D_2CH_SINE_GPU,
+ ]
+ ),
skip_on_empty=True,
)
def test_value(self, arguments, image, expected_data, atol):
diff --git a/tests/test_spacing.py b/tests/test_spacing.py
index 1ff1518297..8b664641d7 100644
--- a/tests/test_spacing.py
+++ b/tests/test_spacing.py
@@ -74,9 +74,11 @@
torch.ones((1, 2, 1, 2)), # data
torch.tensor([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]]),
{},
- torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]])
- if USE_COMPILED
- else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]),
+ (
+ torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]])
+ if USE_COMPILED
+ else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]])
+ ),
*device,
]
)
From f8bfc7cec60332e522920bd9fb10e786f20a4da9 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Thu, 1 Feb 2024 19:47:32 +0800
Subject: [PATCH 55/88] update `rm -rf /opt/hostedtoolcache` avoid change the
python version (#7424)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7416
### Description
update `rm -rf /opt/hostedtoolcache` avoid change the python version
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/docker.yml | 2 +-
.github/workflows/release.yml | 6 +++---
.github/workflows/setupapp.yml | 4 ++--
monai/transforms/smooth_field/array.py | 2 +-
4 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index 229ae675f5..065125cc33 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -62,7 +62,7 @@ jobs:
- name: docker_build
shell: bash
run: |
- rm -rf /opt/hostedtoolcache
+ find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
docker --version
# get tag info for versioning
cat _version.py
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index a03d2cea6c..c134724665 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -27,7 +27,7 @@ jobs:
python -m pip install --user --upgrade setuptools wheel
- name: Build and test source archive and wheel file
run: |
- rm -rf /opt/hostedtoolcache
+ find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
root_dir=$PWD
echo "$root_dir"
@@ -102,7 +102,7 @@ jobs:
python-version: '3.9'
- shell: bash
run: |
- rm -rf /opt/hostedtoolcache
+ find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
git describe
python -m pip install --user --upgrade setuptools wheel
python setup.py build
@@ -143,7 +143,7 @@ jobs:
RELEASE_VERSION: ${{ steps.versioning.outputs.tag }}
shell: bash
run: |
- rm -rf /opt/hostedtoolcache
+ find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
# get tag info for versioning
mv _version.py monai/
# version checks
diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml
index 82394a86dd..a6407deb33 100644
--- a/.github/workflows/setupapp.yml
+++ b/.github/workflows/setupapp.yml
@@ -100,7 +100,7 @@ jobs:
key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ steps.pip-cache.outputs.datew }}
- name: Install the dependencies
run: |
- rm -rf /opt/hostedtoolcache
+ find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
python -m pip install --upgrade pip wheel
python -m pip install -r requirements-dev.txt
- name: Run quick tests CPU ubuntu
@@ -146,7 +146,7 @@ jobs:
- name: Install the default branch with build (dev branch only)
if: github.ref == 'refs/heads/dev'
run: |
- rm -rf /opt/hostedtoolcache
+ find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
BUILD_MONAI=1 pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI
python -c 'import monai; monai.config.print_config()'
- name: Get the test cases (dev branch only)
diff --git a/monai/transforms/smooth_field/array.py b/monai/transforms/smooth_field/array.py
index c9df5f1dbb..9d19263f8b 100644
--- a/monai/transforms/smooth_field/array.py
+++ b/monai/transforms/smooth_field/array.py
@@ -96,7 +96,7 @@ def __init__(
self.set_spatial_size(spatial_size)
def randomize(self, data: Any | None = None) -> None:
- self.field[self.rand_slices] = torch.from_numpy(self.R.uniform(self.low, self.high, self.crand_size))
+ self.field[self.rand_slices] = torch.from_numpy(self.R.uniform(self.low, self.high, self.crand_size)) # type: ignore[index]
def set_spatial_size(self, spatial_size: Sequence[int] | None) -> None:
"""
From 69e7e0562b9df4be50e7de9e4d8c660940c90b4a Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Thu, 1 Feb 2024 13:53:01 +0000
Subject: [PATCH 56/88] Bump peter-evans/slash-command-dispatch from 3.0.2 to
4.0.0 (#7428)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Bumps
[peter-evans/slash-command-dispatch](https://github.com/peter-evans/slash-command-dispatch)
from 3.0.2 to 4.0.0.
Release notes
Sourced from peter-evans/slash-command-dispatch's
releases.
Slash Command Dispatch v4.0.0
⚙️ Updated runtime to Node.js 20
- The action now requires a minimum version of v2.308.0
for the Actions runner. Update self-hosted runners to v2.308.0 or later
to ensure compatibility.
What's Changed
Full Changelog: https://github.com/peter-evans/slash-command-dispatch/compare/v3.0.2...v4.0.0
Commits
13bc097
feat: update runtime to node 20 (#316)
128d8b9
build(deps-dev): bump @types/node from 16.18.74 to
16.18.76 (#318)
df481df
build(deps): bump peter-evans/create-or-update-comment from 3 to 4 (#317)
d4579a0
build(deps-dev): bump @types/node from 16.18.70 to
16.18.74 (#315)
8f053ea
build(deps-dev): bump prettier from 3.2.2 to 3.2.4 (#314)
b3eb783
build(deps-dev): bump prettier from 3.1.1 to 3.2.2 (#312)
c0334d0
build(deps-dev): bump eslint-plugin-prettier from 5.1.2 to 5.1.3 (#311)
e627c61
build(deps-dev): bump @types/node from 16.18.69 to
16.18.70 (#310)
5c23a33
build(deps-dev): bump @types/node from 16.18.68 to
16.18.69 (#309)
8dd62d5
build(deps-dev): bump eslint-plugin-prettier from 5.0.1 to 5.1.2 (#308)
- Additional commits viewable in compare
view
[](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/chatops.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/chatops.yml b/.github/workflows/chatops.yml
index 59c7d070b4..6f3b1c293d 100644
--- a/.github/workflows/chatops.yml
+++ b/.github/workflows/chatops.yml
@@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: dispatch
- uses: peter-evans/slash-command-dispatch@v3.0.2
+ uses: peter-evans/slash-command-dispatch@v4.0.0
with:
token: ${{ secrets.PR_MAINTAIN }}
reaction-token: ${{ secrets.GITHUB_TOKEN }}
From bcea0e8c50d0493670482b1acdcc6fcc2a827b62 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Fri, 2 Feb 2024 16:14:04 +0800
Subject: [PATCH 57/88] Bump peter-evans/create-or-update-comment from 3 to 4
(#7429)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Bumps
[peter-evans/create-or-update-comment](https://github.com/peter-evans/create-or-update-comment)
from 3 to 4.
Release notes
Sourced from peter-evans/create-or-update-comment's
releases.
Create or Update Comment v4.0.0
⚙️ Updated runtime to Node.js 20
- The action now requires a minimum version of v2.308.0
for the Actions runner. Update self-hosted runners to v2.308.0 or later
to ensure compatibility.
What's Changed
Full Changelog: https://github.com/peter-evans/create-or-update-comment/compare/v3.1.0...v4.0.0
Create or Update Comment v3.1.0
What's Changed
Full Changelog: https://github.com/peter-evans/create-or-update-comment/compare/v3.0.2...v3.1.0
Create or Update Comment v3.0.2
What's Changed
... (truncated)
Commits
71345be
feat: update runtime to node 20 (#306)
d41bfe3
build(deps-dev): bump prettier from 3.2.3 to 3.2.4 (#305)
73b4b9e
build(deps-dev): bump @types/node from 18.19.7 to 18.19.8
(#304)
b865fac
build(deps-dev): bump @types/node from 18.19.6 to 18.19.7
(#303)
52b668a
build(deps-dev): bump eslint-plugin-jest from 27.6.1 to 27.6.3 (#302)
974f56a
build(deps-dev): bump prettier from 3.1.1 to 3.2.3 (#301)
2cbfe8b
build(deps-dev): bump @types/node from 18.19.4 to 18.19.6
(#300)
761872a
build(deps-dev): bump eslint-plugin-prettier from 5.1.2 to 5.1.3 (#299)
72c3238
build(deps-dev): bump @types/node from 18.19.3 to 18.19.4
(#298)
07daf7b
build(deps-dev): bump eslint-plugin-jest from 27.6.0 to 27.6.1 (#297)
- Additional commits viewable in compare
view
[](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/integration.yml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index 952b2d8deb..c239f9d5fe 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -71,7 +71,7 @@ jobs:
run: ./runtests.sh --build --net
- name: Add reaction
- uses: peter-evans/create-or-update-comment@v3
+ uses: peter-evans/create-or-update-comment@v4
with:
token: ${{ secrets.PR_MAINTAIN }}
repository: ${{ github.event.client_payload.github.payload.repository.full_name }}
@@ -151,7 +151,7 @@ jobs:
python -m tests.test_integration_gpu_customization
- name: Add reaction
- uses: peter-evans/create-or-update-comment@v3
+ uses: peter-evans/create-or-update-comment@v4
with:
token: ${{ secrets.PR_MAINTAIN }}
repository: ${{ github.event.client_payload.github.payload.repository.full_name }}
From 748dce536fd368409516a1a072da79cf58405d69 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Fri, 2 Feb 2024 09:11:02 +0000
Subject: [PATCH 58/88] Bump actions/cache from 3 to 4 (#7430)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4.
Release notes
Sourced from actions/cache's
releases.
v4.0.0
What's Changed
New Contributors
Full Changelog: https://github.com/actions/cache/compare/v3...v4.0.0
v3.3.3
What's Changed
New Contributors
Full Changelog: https://github.com/actions/cache/compare/v3...v3.3.3
v3.3.2
What's Changed
New Contributors
Full Changelog: https://github.com/actions/cache/compare/v3...v3.3.2
v3.3.1
What's Changed
Full Changelog: https://github.com/actions/cache/compare/v3...v3.3.1
v3.3.0
What's Changed
... (truncated)
Changelog
Sourced from actions/cache's
changelog.
Releases
3.0.0
- Updated minimum runner version support from node 12 -> node
16
3.0.1
- Added support for caching from GHES 3.5.
- Fixed download issue for files > 2GB during restore.
3.0.2
- Added support for dynamic cache size cap on GHES.
3.0.3
- Fixed avoiding empty cache save when no files are available for
caching. (issue)
3.0.4
- Fixed tar creation error while trying to create tar with path as
~/ home folder on ubuntu-latest. (issue)
3.0.5
- Removed error handling by consuming actions/cache 3.0 toolkit, Now
cache server error handling will be done by toolkit. (PR)
3.0.6
- Fixed #809 -
zstd -d: no such file or directory error
- Fixed #833 -
cache doesn't work with github workspace directory
3.0.7
- Fixed #810 -
download stuck issue. A new timeout is introduced in the download
process to abort the download if it gets stuck and doesn't finish within
an hour.
3.0.8
- Fix zstd not working for windows on gnu tar in issues #888 and
#891.
- Allowing users to provide a custom timeout as input for aborting
download of a cache segment using an environment variable
SEGMENT_DOWNLOAD_TIMEOUT_MINS. Default is 60 minutes.
3.0.9
- Enhanced the warning message for cache unavailablity in case of
GHES.
3.0.10
- Fix a bug with sorting inputs.
- Update definition for restore-keys in README.md
... (truncated)
Commits
[](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/cron-ngc-bundle.yml | 2 +-
.github/workflows/integration.yml | 4 ++--
.github/workflows/pythonapp-min.yml | 6 +++---
.github/workflows/pythonapp.yml | 8 ++++----
.github/workflows/setupapp.yml | 6 +++---
5 files changed, 13 insertions(+), 13 deletions(-)
diff --git a/.github/workflows/cron-ngc-bundle.yml b/.github/workflows/cron-ngc-bundle.yml
index 84666204a9..bd45bc8d1e 100644
--- a/.github/workflows/cron-ngc-bundle.yml
+++ b/.github/workflows/cron-ngc-bundle.yml
@@ -26,7 +26,7 @@ jobs:
id: pip-cache
run: echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: ~/.cache/pip
diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index c239f9d5fe..c82530a551 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -22,7 +22,7 @@ jobs:
id: pip-cache
run: echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: |
@@ -95,7 +95,7 @@ jobs:
id: pip-cache
run: echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: |
diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml
index 7b7930bdf5..bbe7579774 100644
--- a/.github/workflows/pythonapp-min.yml
+++ b/.github/workflows/pythonapp-min.yml
@@ -44,7 +44,7 @@ jobs:
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
shell: bash
- name: cache for pip
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: ${{ steps.pip-cache.outputs.dir }}
@@ -90,7 +90,7 @@ jobs:
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
shell: bash
- name: cache for pip
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: ${{ steps.pip-cache.outputs.dir }}
@@ -135,7 +135,7 @@ jobs:
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
shell: bash
- name: cache for pip
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: ${{ steps.pip-cache.outputs.dir }}
diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
index 29a79759e0..b011e65cf1 100644
--- a/.github/workflows/pythonapp.yml
+++ b/.github/workflows/pythonapp.yml
@@ -36,7 +36,7 @@ jobs:
run: |
echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: ~/.cache/pip
@@ -83,7 +83,7 @@ jobs:
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
shell: bash
- name: cache for pip
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: ${{ steps.pip-cache.outputs.dir }}
@@ -136,7 +136,7 @@ jobs:
run: |
echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: |
@@ -217,7 +217,7 @@ jobs:
run: |
echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: |
diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml
index a6407deb33..8f95ffdfc0 100644
--- a/.github/workflows/setupapp.yml
+++ b/.github/workflows/setupapp.yml
@@ -35,7 +35,7 @@ jobs:
echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
if: ${{ startsWith(github.ref, 'refs/heads/dev') }}
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: |
@@ -91,7 +91,7 @@ jobs:
run: |
echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: |
@@ -128,7 +128,7 @@ jobs:
run: |
echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
- uses: actions/cache@v3
+ uses: actions/cache@v4
id: cache
with:
path: |
From f9b4fc2a0cfb14d37f59a2704ccbef7432079bf4 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Fri, 2 Feb 2024 21:29:44 +0800
Subject: [PATCH 59/88] Bump codecov/codecov-action from 3 to 4 (#7431)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Bumps
[codecov/codecov-action](https://github.com/codecov/codecov-action) from
3 to 4.
Release notes
Sourced from codecov/codecov-action's
releases.
v4.0.0
v4 of the Codecov Action uses the CLI as the
underlying upload. The CLI has helped to power new features including
local upload, the global upload token, and new upcoming features.
Breaking Changes
- The Codecov Action runs as a
node20 action due to
node16 deprecation. See this
post from GitHub on how to migrate.
- Tokenless uploading is unsupported. However, PRs made from forks to
the upstream public repos will support tokenless (e.g. contributors to
OS projects do not need the upstream repo's Codecov token). This doc
shows instructions on how to add the Codecov token.
- OS platforms have been added, though some may not be automatically
detected. To see a list of platforms, see our CLI download page
- Various arguments to the Action have been changed. Please be aware
that the arguments match with the CLI's needs
v3 versions and below will not have access to CLI
features (e.g. global upload token, ATS).
What's Changed
... (truncated)
Changelog
Sourced from codecov/codecov-action's
changelog.
4.0.0-beta.2
Fixes
- #1085
not adding -n if empty to do-upload command
4.0.0-beta.1
v4 represents a move from the universal uploader to the
Codecov CLI.
Although this will unlock new features for our users, the CLI is not yet
at feature parity with the universal uploader.
Breaking Changes
- No current support for
aarch64 and alpine
architectures.
- Tokenless uploading is unsuported
- Various arguments to the Action have been removed
3.1.4
Fixes
- #967
Fix typo in README.md
- #971
fix: add back in working dir
- #969
fix: CLI option names for uploader
Dependencies
- #970
build(deps-dev): bump
@types/node from 18.15.12 to
18.16.3
- #979
build(deps-dev): bump
@types/node from 20.1.0 to
20.1.2
- #981
build(deps-dev): bump
@types/node from 20.1.2 to
20.1.4
3.1.3
Fixes
- #960
fix: allow for aarch64 build
Dependencies
- #957
build(deps-dev): bump jest-junit from 15.0.0 to 16.0.0
- #958
build(deps): bump openpgp from 5.7.0 to 5.8.0
- #959
build(deps-dev): bump
@types/node from 18.15.10 to
18.15.12
3.1.2
Fixes
- #718
Update README.md
- #851
Remove unsupported path_to_write_report argument
- #898
codeql-analysis.yml
- #901
Update README to contain correct information - inputs and negate
feature
- #955
fix: add in all the extra arguments for uploader
Dependencies
- #819
build(deps): bump openpgp from 5.4.0 to 5.5.0
- #835
build(deps): bump node-fetch from 3.2.4 to 3.2.10
- #840
build(deps): bump ossf/scorecard-action from 1.1.1 to 2.0.4
- #841
build(deps): bump
@actions/core from 1.9.1 to 1.10.0
- #843
build(deps): bump
@actions/github from 5.0.3 to 5.1.1
- #869
build(deps): bump node-fetch from 3.2.10 to 3.3.0
- #872
build(deps-dev): bump jest-junit from 13.2.0 to 15.0.0
- #879
build(deps): bump decode-uri-component from 0.2.0 to 0.2.2
... (truncated)
Commits
f30e495
fix: update action.yml (#1240)
a7b945c
fix: allow for other archs (#1239)
98ab2c5
Update package.json (#1238)
43235cc
Update README.md (#1237)
0cf8684
chore(ci): bump to node20 (#1236)
8e1e730
build(deps-dev): bump @typescript-eslint/eslint-plugin
from 6.19.1 to 6.20.0 ...
61293af
build(deps-dev): bump @typescript-eslint/parser from
6.19.1 to 6.20.0 (#1235)
7a070cb
build(deps): bump github/codeql-action from 3.23.1 to 3.23.2 (#1231)
9097165
build(deps): bump actions/upload-artifact from 4.2.0 to 4.3.0 (#1232)
ac042ea
build(deps-dev): bump @typescript-eslint/eslint-plugin
from 6.19.0 to 6.19.1 ...
- Additional commits viewable in compare
view
[](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/cron.yml | 6 +++---
.github/workflows/pythonapp-gpu.yml | 2 +-
.github/workflows/setupapp.yml | 4 ++--
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml
index e981280ff9..792fda5279 100644
--- a/.github/workflows/cron.yml
+++ b/.github/workflows/cron.yml
@@ -67,7 +67,7 @@ jobs:
if pgrep python; then pkill python; fi
shell: bash
- name: Upload coverage
- uses: codecov/codecov-action@v3
+ uses: codecov/codecov-action@v4
with:
fail_ci_if_error: false
files: ./coverage.xml
@@ -111,7 +111,7 @@ jobs:
if pgrep python; then pkill python; fi
shell: bash
- name: Upload coverage
- uses: codecov/codecov-action@v3
+ uses: codecov/codecov-action@v4
with:
fail_ci_if_error: false
files: ./coverage.xml
@@ -212,7 +212,7 @@ jobs:
if pgrep python; then pkill python; fi
shell: bash
- name: Upload coverage
- uses: codecov/codecov-action@v3
+ uses: codecov/codecov-action@v4
with:
fail_ci_if_error: false
files: ./coverage.xml
diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml
index 0baef949f0..a6d7981814 100644
--- a/.github/workflows/pythonapp-gpu.yml
+++ b/.github/workflows/pythonapp-gpu.yml
@@ -137,6 +137,6 @@ jobs:
shell: bash
- name: Upload coverage
if: ${{ github.head_ref != 'dev' && github.event.pull_request.merged != true }}
- uses: codecov/codecov-action@v3
+ uses: codecov/codecov-action@v4
with:
files: ./coverage.xml
diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml
index 8f95ffdfc0..c6ad243b81 100644
--- a/.github/workflows/setupapp.yml
+++ b/.github/workflows/setupapp.yml
@@ -68,7 +68,7 @@ jobs:
if pgrep python; then pkill python; fi
shell: bash
- name: Upload coverage
- uses: codecov/codecov-action@v3
+ uses: codecov/codecov-action@v4
with:
fail_ci_if_error: false
files: ./coverage.xml
@@ -111,7 +111,7 @@ jobs:
BUILD_MONAI=1 ./runtests.sh --build --quick --min
coverage xml --ignore-errors
- name: Upload coverage
- uses: codecov/codecov-action@v3
+ uses: codecov/codecov-action@v4
with:
fail_ci_if_error: false
files: ./coverage.xml
From 0dc013d1b2f4b4bc70eff6a6f2ab20e2e100d0b1 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Fri, 2 Feb 2024 22:56:30 +0800
Subject: [PATCH 60/88] Update tensorboard version to fix deadlock (#7435)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7434
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
requirements-dev.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index f8bc9d5a3e..706980576c 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -6,7 +6,7 @@ scipy>=1.7.1
itk>=5.2
nibabel
pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571
-tensorboard>=2.6 # https://github.com/Project-MONAI/MONAI/issues/5776
+tensorboard>=2.12.0 # https://github.com/Project-MONAI/MONAI/issues/7434
scikit-image>=0.19.0
tqdm>=4.47.0
lmdb
From c3ca41cd92de2ccff355db4b3817c0c21ee20c7a Mon Sep 17 00:00:00 2001
From: monai-bot <64792179+monai-bot@users.noreply.github.com>
Date: Mon, 5 Feb 2024 15:43:23 +0000
Subject: [PATCH 61/88] auto updates (#7439)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: monai-bot
Signed-off-by: monai-bot
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/apps/datasets.py | 1 +
monai/apps/deepedit/transforms.py | 3 +++
monai/apps/detection/metrics/coco.py | 1 +
monai/apps/detection/utils/ATSS_matcher.py | 1 +
monai/apps/pathology/transforms/post/array.py | 6 +++---
monai/data/dataset.py | 1 +
monai/fl/client/client_algo.py | 1 +
monai/handlers/ignite_metric.py | 1 +
monai/metrics/f_beta_score.py | 1 +
monai/networks/blocks/dynunet_block.py | 1 +
monai/networks/blocks/localnet_block.py | 2 ++
monai/networks/blocks/pos_embed_utils.py | 1 +
monai/networks/layers/gmm.py | 1 +
monai/networks/layers/simplelayers.py | 2 ++
monai/networks/layers/spatial_transforms.py | 5 +++++
monai/networks/nets/ahnet.py | 6 ++++++
monai/networks/nets/attentionunet.py | 4 ++++
monai/networks/nets/basic_unet.py | 1 +
monai/networks/nets/basic_unetplusplus.py | 1 +
monai/networks/nets/densenet.py | 3 +++
monai/networks/nets/dints.py | 4 ++++
monai/networks/nets/efficientnet.py | 4 ++++
monai/networks/nets/highresnet.py | 1 +
monai/networks/nets/hovernet.py | 6 ++++++
monai/networks/nets/milmodel.py | 1 +
monai/networks/nets/regunet.py | 2 ++
monai/networks/nets/vnet.py | 5 +++++
monai/optimizers/lr_finder.py | 2 ++
monai/optimizers/utils.py | 2 ++
monai/transforms/adaptors.py | 4 ++++
monai/transforms/inverse_batch_transform.py | 1 +
monai/transforms/post/array.py | 1 +
monai/transforms/spatial/array.py | 2 +-
monai/transforms/utility/dictionary.py | 1 +
monai/utils/misc.py | 1 +
monai/utils/module.py | 2 ++
monai/utils/profiling.py | 1 +
monai/visualize/class_activation_maps.py | 2 ++
monai/visualize/gradient_based.py | 1 +
tests/croppers.py | 1 +
tests/hvd_evenly_divisible_all_gather.py | 1 +
tests/ngc_bundle_download.py | 2 ++
tests/padders.py | 1 +
tests/profile_subclass/min_classes.py | 1 +
tests/test_acn_block.py | 1 +
tests/test_activations.py | 1 +
tests/test_activationsd.py | 1 +
tests/test_adaptors.py | 11 +++++++++++
tests/test_add_coordinate_channels.py | 1 +
tests/test_add_coordinate_channelsd.py | 1 +
tests/test_add_extreme_points_channel.py | 1 +
tests/test_add_extreme_points_channeld.py | 1 +
tests/test_adjust_contrast.py | 1 +
tests/test_adjust_contrastd.py | 1 +
tests/test_adn.py | 2 ++
tests/test_adversarial_loss.py | 1 +
tests/test_affine.py | 2 ++
tests/test_affine_grid.py | 1 +
tests/test_affine_transform.py | 3 +++
tests/test_affined.py | 1 +
tests/test_ahnet.py | 6 ++++++
tests/test_anchor_box.py | 1 +
tests/test_apply.py | 1 +
tests/test_apply_filter.py | 1 +
tests/test_arraydataset.py | 2 ++
tests/test_as_channel_last.py | 1 +
tests/test_as_channel_lastd.py | 1 +
tests/test_as_discrete.py | 1 +
tests/test_as_discreted.py | 1 +
tests/test_atss_box_matcher.py | 1 +
tests/test_attentionunet.py | 1 +
tests/test_auto3dseg.py | 1 +
tests/test_auto3dseg_bundlegen.py | 1 +
tests/test_auto3dseg_ensemble.py | 1 +
tests/test_auto3dseg_hpo.py | 2 ++
tests/test_autoencoder.py | 1 +
tests/test_avg_merger.py | 1 +
tests/test_basic_unet.py | 1 +
tests/test_basic_unetplusplus.py | 1 +
tests/test_bending_energy.py | 1 +
tests/test_bilateral_approx_cpu.py | 1 +
tests/test_bilateral_approx_cuda.py | 1 +
tests/test_bilateral_precise.py | 2 ++
tests/test_blend_images.py | 1 +
tests/test_bounding_rect.py | 1 +
tests/test_bounding_rectd.py | 1 +
tests/test_box_coder.py | 1 +
tests/test_box_transform.py | 1 +
tests/test_box_utils.py | 1 +
tests/test_bundle_ckpt_export.py | 1 +
tests/test_bundle_download.py | 3 +++
tests/test_bundle_get_data.py | 1 +
tests/test_bundle_init_bundle.py | 1 +
tests/test_bundle_onnx_export.py | 1 +
tests/test_bundle_push_to_hf_hub.py | 1 +
tests/test_bundle_trt_export.py | 1 +
tests/test_bundle_utils.py | 2 ++
tests/test_bundle_verify_metadata.py | 1 +
tests/test_bundle_verify_net.py | 1 +
tests/test_bundle_workflow.py | 1 +
tests/test_cachedataset.py | 1 +
tests/test_cachedataset_parallel.py | 1 +
tests/test_cachedataset_persistent_workers.py | 1 +
tests/test_cachentransdataset.py | 1 +
tests/test_call_dist.py | 1 +
tests/test_cast_to_type.py | 1 +
tests/test_cast_to_typed.py | 1 +
tests/test_channel_pad.py | 1 +
tests/test_check_hash.py | 1 +
tests/test_check_missing_files.py | 1 +
tests/test_classes_to_indices.py | 1 +
tests/test_classes_to_indicesd.py | 1 +
tests/test_cldice_loss.py | 1 +
tests/test_complex_utils.py | 1 +
tests/test_component_locator.py | 1 +
tests/test_component_store.py | 1 +
tests/test_compose.py | 18 ++++++++++++++++++
tests/test_compose_get_number_conversions.py | 7 +++++++
tests/test_compute_confusion_matrix.py | 1 +
tests/test_compute_f_beta.py | 1 +
tests/test_compute_fid_metric.py | 1 +
tests/test_compute_froc.py | 3 +++
tests/test_compute_generalized_dice.py | 1 +
tests/test_compute_ho_ver_maps.py | 1 +
tests/test_compute_ho_ver_maps_d.py | 1 +
tests/test_compute_meandice.py | 1 +
tests/test_compute_meaniou.py | 1 +
tests/test_compute_mmd_metric.py | 1 +
tests/test_compute_multiscalessim_metric.py | 1 +
tests/test_compute_panoptic_quality.py | 1 +
tests/test_compute_regression_metrics.py | 1 +
tests/test_compute_roc_auc.py | 1 +
tests/test_compute_variance.py | 1 +
tests/test_concat_itemsd.py | 1 +
tests/test_config_item.py | 1 +
tests/test_config_parser.py | 2 ++
tests/test_contrastive_loss.py | 1 +
tests/test_convert_data_type.py | 1 +
tests/test_convert_to_multi_channel.py | 1 +
tests/test_convert_to_multi_channeld.py | 1 +
tests/test_convert_to_onnx.py | 1 +
tests/test_convert_to_torchscript.py | 1 +
tests/test_convert_to_trt.py | 1 +
tests/test_convolutions.py | 3 +++
tests/test_copy_itemsd.py | 1 +
tests/test_copy_model_state.py | 3 +++
tests/test_correct_crop_centers.py | 1 +
tests/test_create_cross_validation_datalist.py | 1 +
tests/test_create_grid_and_affine.py | 2 ++
tests/test_crf_cpu.py | 1 +
tests/test_crf_cuda.py | 1 +
tests/test_crop_foreground.py | 1 +
tests/test_crop_foregroundd.py | 1 +
tests/test_cross_validation.py | 1 +
tests/test_csv_dataset.py | 1 +
tests/test_csv_iterable_dataset.py | 1 +
tests/test_csv_saver.py | 1 +
tests/test_cucim_dict_transform.py | 1 +
tests/test_cucim_transform.py | 1 +
tests/test_cumulative.py | 1 +
tests/test_cumulative_average.py | 1 +
tests/test_cumulative_average_dist.py | 1 +
tests/test_cv2_dist.py | 1 +
tests/test_daf3d.py | 1 +
tests/test_data_stats.py | 1 +
tests/test_data_statsd.py | 1 +
tests/test_dataloader.py | 2 ++
tests/test_dataset.py | 1 +
tests/test_dataset_func.py | 1 +
tests/test_dataset_summary.py | 1 +
tests/test_decathlondataset.py | 1 +
tests/test_decollate.py | 2 ++
tests/test_deepedit_interaction.py | 1 +
tests/test_deepedit_transforms.py | 11 +++++++++++
tests/test_deepgrow_dataset.py | 1 +
tests/test_deepgrow_interaction.py | 1 +
tests/test_deepgrow_transforms.py | 11 +++++++++++
tests/test_delete_itemsd.py | 1 +
tests/test_denseblock.py | 4 ++++
tests/test_densenet.py | 2 ++
tests/test_deprecated.py | 8 ++++++++
tests/test_detect_envelope.py | 2 ++
tests/test_detection_coco_metrics.py | 1 +
tests/test_detector_boxselector.py | 1 +
tests/test_detector_utils.py | 1 +
tests/test_dev_collate.py | 1 +
tests/test_dice_ce_loss.py | 1 +
tests/test_dice_focal_loss.py | 1 +
tests/test_dice_loss.py | 1 +
tests/test_diffusion_loss.py | 1 +
tests/test_dints_cell.py | 1 +
tests/test_dints_mixop.py | 1 +
tests/test_dints_network.py | 2 ++
tests/test_discriminator.py | 1 +
tests/test_distance_transform_edt.py | 1 +
tests/test_download_and_extract.py | 1 +
tests/test_download_url_yandex.py | 1 +
tests/test_downsample_block.py | 1 +
tests/test_drop_path.py | 1 +
tests/test_ds_loss.py | 4 ++++
tests/test_dvf2ddf.py | 1 +
tests/test_dynunet.py | 3 +++
tests/test_dynunet_block.py | 2 ++
tests/test_efficientnet.py | 2 ++
tests/test_ensemble_evaluator.py | 3 +++
tests/test_ensure_channel_first.py | 1 +
tests/test_ensure_channel_firstd.py | 1 +
tests/test_ensure_tuple.py | 1 +
tests/test_ensure_type.py | 1 +
tests/test_ensure_typed.py | 1 +
tests/test_enum_bound_interp.py | 1 +
tests/test_eval_mode.py | 1 +
tests/test_evenly_divisible_all_gather_dist.py | 1 +
tests/test_factorized_increase.py | 1 +
tests/test_factorized_reduce.py | 1 +
tests/test_fastmri_reader.py | 1 +
tests/test_fft_utils.py | 1 +
tests/test_fg_bg_to_indices.py | 1 +
tests/test_fg_bg_to_indicesd.py | 1 +
tests/test_file_basename.py | 1 +
tests/test_fill_holes.py | 1 +
tests/test_fill_holesd.py | 1 +
tests/test_fl_exchange_object.py | 1 +
tests/test_fl_monai_algo.py | 1 +
tests/test_fl_monai_algo_dist.py | 1 +
tests/test_fl_monai_algo_stats.py | 1 +
tests/test_flatten_sub_keysd.py | 1 +
tests/test_flexible_unet.py | 3 +++
tests/test_flip.py | 1 +
tests/test_flipd.py | 1 +
tests/test_focal_loss.py | 1 +
tests/test_folder_layout.py | 1 +
tests/test_foreground_mask.py | 1 +
tests/test_foreground_maskd.py | 1 +
tests/test_fourier.py | 1 +
tests/test_fpn_block.py | 2 ++
tests/test_freeze_layers.py | 1 +
tests/test_from_engine_hovernet.py | 1 +
tests/test_fullyconnectednet.py | 1 +
tests/test_gaussian.py | 1 +
tests/test_gaussian_filter.py | 2 ++
tests/test_gaussian_sharpen.py | 1 +
tests/test_gaussian_sharpend.py | 1 +
tests/test_gaussian_smooth.py | 1 +
tests/test_gaussian_smoothd.py | 1 +
tests/test_gdsdataset.py | 2 ++
tests/test_generalized_dice_focal_loss.py | 1 +
tests/test_generalized_dice_loss.py | 1 +
.../test_generalized_wasserstein_dice_loss.py | 2 ++
tests/test_generate_distance_map.py | 1 +
tests/test_generate_distance_mapd.py | 1 +
tests/test_generate_instance_border.py | 1 +
tests/test_generate_instance_borderd.py | 1 +
tests/test_generate_instance_centroid.py | 1 +
tests/test_generate_instance_centroidd.py | 1 +
tests/test_generate_instance_contour.py | 1 +
tests/test_generate_instance_contourd.py | 1 +
tests/test_generate_instance_type.py | 1 +
tests/test_generate_instance_typed.py | 1 +
...test_generate_label_classes_crop_centers.py | 1 +
tests/test_generate_param_groups.py | 1 +
...test_generate_pos_neg_label_crop_centers.py | 1 +
tests/test_generate_spatial_bounding_box.py | 1 +
tests/test_generate_succinct_contour.py | 1 +
tests/test_generate_succinct_contourd.py | 1 +
tests/test_generate_watershed_markers.py | 1 +
tests/test_generate_watershed_markersd.py | 1 +
tests/test_generate_watershed_mask.py | 1 +
tests/test_generate_watershed_maskd.py | 1 +
tests/test_generator.py | 1 +
tests/test_get_equivalent_dtype.py | 1 +
tests/test_get_extreme_points.py | 1 +
tests/test_get_layers.py | 2 ++
tests/test_get_package_version.py | 1 +
tests/test_get_unique_labels.py | 1 +
tests/test_gibbs_noise.py | 1 +
tests/test_gibbs_noised.py | 1 +
tests/test_giou_loss.py | 1 +
tests/test_global_mutual_information_loss.py | 2 ++
tests/test_globalnet.py | 2 ++
tests/test_gmm.py | 1 +
tests/test_grid_dataset.py | 1 +
tests/test_grid_distortion.py | 1 +
tests/test_grid_distortiond.py | 1 +
tests/test_grid_patch.py | 1 +
tests/test_grid_patchd.py | 1 +
tests/test_grid_pull.py | 1 +
tests/test_grid_split.py | 1 +
tests/test_grid_splitd.py | 1 +
tests/test_handler_checkpoint_loader.py | 1 +
tests/test_handler_checkpoint_saver.py | 1 +
tests/test_handler_classification_saver.py | 1 +
.../test_handler_classification_saver_dist.py | 1 +
tests/test_handler_clearml_image.py | 1 +
tests/test_handler_clearml_stats.py | 1 +
tests/test_handler_confusion_matrix_dist.py | 1 +
tests/test_handler_decollate_batch.py | 1 +
tests/test_handler_early_stop.py | 3 +++
tests/test_handler_garbage_collector.py | 1 +
tests/test_handler_ignite_metric.py | 1 +
tests/test_handler_logfile.py | 1 +
tests/test_handler_lr_scheduler.py | 1 +
tests/test_handler_metric_logger.py | 1 +
tests/test_handler_metrics_reloaded.py | 2 ++
tests/test_handler_metrics_saver.py | 1 +
tests/test_handler_metrics_saver_dist.py | 1 +
tests/test_handler_mlflow.py | 2 ++
tests/test_handler_nvtx.py | 1 +
tests/test_handler_panoptic_quality.py | 1 +
tests/test_handler_parameter_scheduler.py | 3 +++
tests/test_handler_post_processing.py | 1 +
tests/test_handler_prob_map_producer.py | 3 +++
tests/test_handler_regression_metrics.py | 1 +
tests/test_handler_regression_metrics_dist.py | 4 ++++
tests/test_handler_rocauc.py | 1 +
tests/test_handler_rocauc_dist.py | 1 +
tests/test_handler_smartcache.py | 1 +
tests/test_handler_stats.py | 2 ++
tests/test_handler_tb_image.py | 1 +
tests/test_handler_tb_stats.py | 2 ++
tests/test_handler_validation.py | 2 ++
tests/test_hardnegsampler.py | 1 +
tests/test_hashing.py | 2 ++
tests/test_hausdorff_distance.py | 1 +
tests/test_hausdorff_loss.py | 2 ++
tests/test_header_correct.py | 1 +
tests/test_highresnet.py | 1 +
tests/test_hilbert_transform.py | 3 +++
tests/test_histogram_normalize.py | 1 +
tests/test_histogram_normalized.py | 1 +
tests/test_hovernet.py | 1 +
...st_hovernet_instance_map_post_processing.py | 1 +
...t_hovernet_instance_map_post_processingd.py | 1 +
tests/test_hovernet_loss.py | 2 ++
...st_hovernet_nuclear_type_post_processing.py | 1 +
...t_hovernet_nuclear_type_post_processingd.py | 1 +
tests/test_identity.py | 1 +
tests/test_identityd.py | 1 +
tests/test_image_dataset.py | 2 ++
tests/test_image_filter.py | 5 +++++
tests/test_image_rw.py | 4 ++++
tests/test_img2tensorboard.py | 1 +
tests/test_init_reader.py | 1 +
tests/test_integration_autorunner.py | 1 +
tests/test_integration_bundle_run.py | 3 +++
tests/test_integration_classification_2d.py | 2 ++
tests/test_integration_determinism.py | 3 +++
tests/test_integration_fast_train.py | 1 +
tests/test_integration_gpu_customization.py | 1 +
tests/test_integration_lazy_samples.py | 1 +
tests/test_integration_nnunetv2_runner.py | 1 +
tests/test_integration_segmentation_3d.py | 1 +
tests/test_integration_sliding_window.py | 1 +
tests/test_integration_stn.py | 1 +
tests/test_integration_unet_2d.py | 3 +++
tests/test_integration_workers.py | 1 +
tests/test_integration_workflows.py | 3 +++
tests/test_integration_workflows_gan.py | 1 +
tests/test_intensity_stats.py | 1 +
tests/test_intensity_statsd.py | 1 +
tests/test_inverse_array.py | 1 +
tests/test_invert.py | 1 +
tests/test_invertd.py | 1 +
tests/test_is_supported_format.py | 1 +
tests/test_iterable_dataset.py | 2 ++
tests/test_itk_torch_bridge.py | 2 ++
tests/test_itk_writer.py | 1 +
tests/test_k_space_spike_noise.py | 1 +
tests/test_k_space_spike_noised.py | 1 +
tests/test_keep_largest_connected_component.py | 1 +
.../test_keep_largest_connected_componentd.py | 1 +
tests/test_kspace_mask.py | 1 +
tests/test_label_filter.py | 1 +
tests/test_label_filterd.py | 1 +
tests/test_label_quality_score.py | 1 +
tests/test_label_to_contour.py | 1 +
tests/test_label_to_contourd.py | 1 +
tests/test_label_to_mask.py | 1 +
tests/test_label_to_maskd.py | 1 +
tests/test_lambda.py | 1 +
tests/test_lambdad.py | 1 +
tests/test_lesion_froc.py | 1 +
tests/test_list_data_collate.py | 1 +
tests/test_list_to_dict.py | 1 +
tests/test_lltm.py | 1 +
tests/test_lmdbdataset.py | 2 ++
tests/test_lmdbdataset_dist.py | 2 ++
tests/test_load_decathlon_datalist.py | 1 +
tests/test_load_image.py | 2 ++
tests/test_load_imaged.py | 3 +++
tests/test_load_spacing_orientation.py | 1 +
tests/test_loader_semaphore.py | 1 +
..._local_normalized_cross_correlation_loss.py | 1 +
tests/test_localnet.py | 1 +
tests/test_localnet_block.py | 3 +++
tests/test_look_up_option.py | 1 +
tests/test_loss_metric.py | 1 +
tests/test_lr_finder.py | 1 +
tests/test_lr_scheduler.py | 2 ++
tests/test_make_nifti.py | 1 +
tests/test_map_binary_to_indices.py | 1 +
tests/test_map_classes_to_indices.py | 1 +
tests/test_map_label_value.py | 1 +
tests/test_map_label_valued.py | 1 +
tests/test_map_transform.py | 2 ++
tests/test_mask_intensity.py | 1 +
tests/test_mask_intensityd.py | 1 +
tests/test_masked_dice_loss.py | 1 +
tests/test_masked_loss.py | 1 +
tests/test_masked_patch_wsi_dataset.py | 3 +++
tests/test_matshow3d.py | 1 +
tests/test_mean_ensemble.py | 1 +
tests/test_mean_ensembled.py | 1 +
tests/test_median_filter.py | 1 +
tests/test_median_smooth.py | 1 +
tests/test_median_smoothd.py | 1 +
tests/test_mednistdataset.py | 1 +
tests/test_meta_affine.py | 1 +
tests/test_meta_tensor.py | 1 +
tests/test_metatensor_integration.py | 1 +
tests/test_metrics_reloaded.py | 1 +
tests/test_milmodel.py | 1 +
tests/test_mlp.py | 1 +
tests/test_mmar_download.py | 1 +
tests/test_module_list.py | 1 +
tests/test_monai_env_vars.py | 1 +
tests/test_monai_utils_misc.py | 4 ++++
tests/test_mri_utils.py | 1 +
tests/test_multi_scale.py | 1 +
tests/test_net_adapter.py | 1 +
tests/test_network_consistency.py | 1 +
tests/test_nifti_endianness.py | 1 +
tests/test_nifti_header_revise.py | 1 +
tests/test_nifti_rw.py | 1 +
tests/test_normalize_intensity.py | 1 +
tests/test_normalize_intensityd.py | 1 +
tests/test_npzdictitemdataset.py | 1 +
tests/test_nrrd_reader.py | 1 +
tests/test_nuclick_transforms.py | 9 +++++++++
tests/test_numpy_reader.py | 1 +
tests/test_nvtx_decorator.py | 1 +
tests/test_nvtx_transform.py | 1 +
tests/test_occlusion_sensitivity.py | 2 ++
tests/test_one_of.py | 12 ++++++++++++
tests/test_optional_import.py | 1 +
tests/test_ori_ras_lps.py | 1 +
tests/test_orientation.py | 1 +
tests/test_orientationd.py | 1 +
tests/test_p3d_block.py | 1 +
tests/test_pad_collation.py | 2 ++
tests/test_pad_mode.py | 1 +
tests/test_partition_dataset.py | 1 +
tests/test_partition_dataset_classes.py | 1 +
tests/test_patch_dataset.py | 1 +
tests/test_patch_inferer.py | 1 +
tests/test_patch_wsi_dataset.py | 3 +++
tests/test_patchembedding.py | 2 ++
tests/test_pathology_he_stain.py | 2 ++
tests/test_pathology_he_stain_dict.py | 2 ++
tests/test_pathology_prob_nms.py | 1 +
tests/test_perceptual_loss.py | 1 +
tests/test_persistentdataset.py | 2 ++
tests/test_persistentdataset_dist.py | 3 +++
tests/test_phl_cpu.py | 1 +
tests/test_phl_cuda.py | 1 +
tests/test_pil_reader.py | 1 +
tests/test_plot_2d_or_3d_image.py | 1 +
tests/test_png_rw.py | 1 +
tests/test_polyval.py | 1 +
tests/test_prepare_batch_default.py | 2 ++
tests/test_prepare_batch_default_dist.py | 2 ++
tests/test_prepare_batch_extra_input.py | 2 ++
tests/test_prepare_batch_hovernet.py | 2 ++
tests/test_preset_filters.py | 6 ++++++
tests/test_print_info.py | 1 +
tests/test_print_transform_backends.py | 1 +
tests/test_probnms.py | 1 +
tests/test_probnmsd.py | 1 +
tests/test_profiling.py | 1 +
tests/test_pytorch_version_after.py | 1 +
tests/test_query_memory.py | 1 +
tests/test_quicknat.py | 1 +
tests/test_rand_adjust_contrast.py | 1 +
tests/test_rand_adjust_contrastd.py | 1 +
tests/test_rand_affine.py | 1 +
tests/test_rand_affine_grid.py | 1 +
tests/test_rand_affined.py | 1 +
tests/test_rand_axis_flip.py | 1 +
tests/test_rand_axis_flipd.py | 1 +
tests/test_rand_bias_field.py | 1 +
tests/test_rand_bias_fieldd.py | 1 +
tests/test_rand_coarse_dropout.py | 1 +
tests/test_rand_coarse_dropoutd.py | 1 +
tests/test_rand_coarse_shuffle.py | 1 +
tests/test_rand_coarse_shuffled.py | 1 +
tests/test_rand_crop_by_label_classes.py | 1 +
tests/test_rand_crop_by_label_classesd.py | 1 +
tests/test_rand_crop_by_pos_neg_label.py | 1 +
tests/test_rand_crop_by_pos_neg_labeld.py | 1 +
tests/test_rand_cucim_dict_transform.py | 1 +
tests/test_rand_cucim_transform.py | 1 +
tests/test_rand_deform_grid.py | 1 +
tests/test_rand_elastic_2d.py | 1 +
tests/test_rand_elastic_3d.py | 1 +
tests/test_rand_elasticd_2d.py | 1 +
tests/test_rand_elasticd_3d.py | 1 +
tests/test_rand_flip.py | 1 +
tests/test_rand_flipd.py | 1 +
tests/test_rand_gaussian_noise.py | 1 +
tests/test_rand_gaussian_noised.py | 1 +
tests/test_rand_gaussian_sharpen.py | 1 +
tests/test_rand_gaussian_sharpend.py | 1 +
tests/test_rand_gaussian_smooth.py | 1 +
tests/test_rand_gaussian_smoothd.py | 1 +
tests/test_rand_gibbs_noise.py | 1 +
tests/test_rand_gibbs_noised.py | 1 +
tests/test_rand_grid_distortion.py | 1 +
tests/test_rand_grid_distortiond.py | 1 +
tests/test_rand_grid_patch.py | 1 +
tests/test_rand_grid_patchd.py | 1 +
tests/test_rand_histogram_shift.py | 1 +
tests/test_rand_histogram_shiftd.py | 1 +
tests/test_rand_k_space_spike_noise.py | 1 +
tests/test_rand_k_space_spike_noised.py | 1 +
tests/test_rand_lambda.py | 1 +
tests/test_rand_lambdad.py | 1 +
tests/test_rand_rician_noise.py | 1 +
tests/test_rand_rician_noised.py | 1 +
tests/test_rand_rotate.py | 3 +++
tests/test_rand_rotate90.py | 1 +
tests/test_rand_rotate90d.py | 1 +
tests/test_rand_rotated.py | 2 ++
tests/test_rand_scale_intensity.py | 1 +
tests/test_rand_scale_intensity_fixed_mean.py | 1 +
tests/test_rand_scale_intensity_fixed_meand.py | 1 +
tests/test_rand_scale_intensityd.py | 1 +
tests/test_rand_shift_intensity.py | 1 +
tests/test_rand_shift_intensityd.py | 1 +
tests/test_rand_simulate_low_resolution.py | 1 +
tests/test_rand_simulate_low_resolutiond.py | 1 +
tests/test_rand_spatial_crop_samplesd.py | 1 +
tests/test_rand_std_shift_intensity.py | 1 +
tests/test_rand_std_shift_intensityd.py | 1 +
tests/test_rand_weighted_cropd.py | 1 +
tests/test_rand_zoom.py | 1 +
tests/test_rand_zoomd.py | 1 +
tests/test_randidentity.py | 2 ++
tests/test_random_order.py | 4 ++++
tests/test_randomizable.py | 2 ++
tests/test_randomizable_transform_type.py | 2 ++
tests/test_randtorchvisiond.py | 1 +
tests/test_rankfilter_dist.py | 2 ++
tests/test_recon_net_utils.py | 1 +
...test_reference_based_normalize_intensity.py | 1 +
tests/test_reference_based_spatial_cropd.py | 1 +
tests/test_reference_resolver.py | 1 +
tests/test_reg_loss_integration.py | 2 ++
tests/test_regunet.py | 1 +
tests/test_regunet_block.py | 3 +++
tests/test_remove_repeated_channel.py | 1 +
tests/test_remove_repeated_channeld.py | 1 +
tests/test_remove_small_objects.py | 1 +
tests/test_repeat_channel.py | 1 +
tests/test_repeat_channeld.py | 1 +
tests/test_replace_module.py | 1 +
tests/test_require_pkg.py | 4 ++++
tests/test_resample.py | 1 +
tests/test_resample_backends.py | 1 +
tests/test_resample_datalist.py | 1 +
tests/test_resample_to_match.py | 1 +
tests/test_resample_to_matchd.py | 1 +
tests/test_resampler.py | 1 +
tests/test_resize.py | 1 +
tests/test_resize_with_pad_or_crop.py | 1 +
tests/test_resize_with_pad_or_cropd.py | 1 +
tests/test_resized.py | 1 +
tests/test_resnet.py | 1 +
tests/test_retinanet.py | 1 +
tests/test_retinanet_detector.py | 2 ++
tests/test_retinanet_predict_utils.py | 3 +++
tests/test_rotate.py | 2 ++
tests/test_rotate90.py | 3 +++
tests/test_rotate90d.py | 1 +
tests/test_rotated.py | 3 +++
tests/test_safe_dtype_range.py | 1 +
tests/test_saliency_inferer.py | 1 +
tests/test_sample_slices.py | 1 +
tests/test_sampler_dist.py | 1 +
tests/test_save_classificationd.py | 1 +
tests/test_save_image.py | 1 +
tests/test_save_imaged.py | 3 +++
tests/test_save_state.py | 1 +
tests/test_savitzky_golay_filter.py | 4 ++++
tests/test_savitzky_golay_smooth.py | 1 +
tests/test_savitzky_golay_smoothd.py | 1 +
tests/test_scale_intensity.py | 1 +
tests/test_scale_intensity_fixed_mean.py | 1 +
tests/test_scale_intensity_range.py | 1 +
.../test_scale_intensity_range_percentiles.py | 1 +
.../test_scale_intensity_range_percentilesd.py | 1 +
tests/test_scale_intensity_ranged.py | 1 +
tests/test_scale_intensityd.py | 1 +
tests/test_se_block.py | 1 +
tests/test_se_blocks.py | 2 ++
tests/test_seg_loss_integration.py | 2 ++
tests/test_segresnet.py | 2 ++
tests/test_segresnet_block.py | 1 +
tests/test_segresnet_ds.py | 1 +
tests/test_select_cross_validation_folds.py | 1 +
tests/test_select_itemsd.py | 1 +
tests/test_selfattention.py | 1 +
tests/test_senet.py | 2 ++
tests/test_separable_filter.py | 1 +
tests/test_set_determinism.py | 2 ++
tests/test_set_visible_devices.py | 1 +
tests/test_shift_intensity.py | 1 +
tests/test_shift_intensityd.py | 1 +
tests/test_shuffle_buffer.py | 1 +
tests/test_signal_continuouswavelet.py | 1 +
tests/test_signal_fillempty.py | 2 ++
tests/test_signal_fillemptyd.py | 2 ++
tests/test_signal_rand_add_gaussiannoise.py | 2 ++
tests/test_signal_rand_add_sine.py | 2 ++
tests/test_signal_rand_add_sine_partial.py | 2 ++
tests/test_signal_rand_add_squarepulse.py | 2 ++
...test_signal_rand_add_squarepulse_partial.py | 2 ++
tests/test_signal_rand_drop.py | 2 ++
tests/test_signal_rand_scale.py | 2 ++
tests/test_signal_rand_shift.py | 2 ++
tests/test_signal_remove_frequency.py | 2 ++
tests/test_simple_aspp.py | 1 +
tests/test_simulatedelay.py | 1 +
tests/test_simulatedelayd.py | 1 +
tests/test_skip_connection.py | 1 +
tests/test_slice_inferer.py | 1 +
tests/test_sliding_patch_wsi_dataset.py | 3 +++
.../test_sliding_window_hovernet_inference.py | 1 +
tests/test_sliding_window_inference.py | 2 ++
tests/test_sliding_window_splitter.py | 1 +
tests/test_smartcachedataset.py | 1 +
tests/test_smooth_field.py | 1 +
tests/test_some_of.py | 6 ++++++
tests/test_spacing.py | 1 +
tests/test_spacingd.py | 1 +
tests/test_spatial_combine_transforms.py | 1 +
tests/test_spatial_resample.py | 1 +
tests/test_spatial_resampled.py | 1 +
tests/test_spectral_loss.py | 1 +
tests/test_splitdim.py | 1 +
tests/test_squeeze_unsqueeze.py | 1 +
tests/test_squeezedim.py | 1 +
tests/test_squeezedimd.py | 1 +
tests/test_ssim_loss.py | 1 +
tests/test_ssim_metric.py | 1 +
tests/test_state_cacher.py | 1 +
tests/test_std_shift_intensity.py | 1 +
tests/test_std_shift_intensityd.py | 1 +
tests/test_str2bool.py | 1 +
tests/test_str2list.py | 1 +
tests/test_subpixel_upsample.py | 1 +
tests/test_surface_dice.py | 1 +
tests/test_surface_distance.py | 1 +
tests/test_swin_unetr.py | 1 +
tests/test_synthetic.py | 1 +
tests/test_tciadataset.py | 1 +
tests/test_testtimeaugmentation.py | 1 +
tests/test_text_encoding.py | 1 +
tests/test_thread_buffer.py | 1 +
tests/test_threadcontainer.py | 1 +
tests/test_threshold_intensity.py | 1 +
tests/test_threshold_intensityd.py | 1 +
tests/test_timedcall_dist.py | 1 +
tests/test_to_contiguous.py | 1 +
tests/test_to_cupy.py | 1 +
tests/test_to_cupyd.py | 1 +
tests/test_to_device.py | 1 +
tests/test_to_deviced.py | 1 +
tests/test_to_from_meta_tensord.py | 1 +
tests/test_to_numpy.py | 1 +
tests/test_to_numpyd.py | 1 +
tests/test_to_onehot.py | 1 +
tests/test_to_pil.py | 1 +
tests/test_to_pild.py | 1 +
tests/test_to_tensor.py | 1 +
tests/test_to_tensord.py | 1 +
tests/test_torchscript_utils.py | 2 ++
tests/test_torchvision.py | 1 +
tests/test_torchvision_fc_model.py | 2 ++
tests/test_torchvisiond.py | 1 +
tests/test_traceable_transform.py | 2 ++
tests/test_train_mode.py | 1 +
tests/test_trainable_bilateral.py | 2 ++
tests/test_trainable_joint_bilateral.py | 2 ++
tests/test_transchex.py | 1 +
tests/test_transform.py | 2 ++
tests/test_transformerblock.py | 1 +
tests/test_transpose.py | 1 +
tests/test_transposed.py | 1 +
tests/test_tversky_loss.py | 1 +
...test_ultrasound_confidence_map_transform.py | 1 +
tests/test_unet.py | 1 +
tests/test_unetr.py | 1 +
tests/test_unetr_block.py | 3 +++
tests/test_unified_focal_loss.py | 1 +
tests/test_upsample_block.py | 1 +
tests/test_utils_pytorch_numpy_unification.py | 1 +
tests/test_varautoencoder.py | 1 +
tests/test_varnet.py | 1 +
tests/test_version.py | 1 +
tests/test_video_datasets.py | 1 +
tests/test_vis_cam.py | 1 +
tests/test_vis_gradbased.py | 2 ++
tests/test_vis_gradcam.py | 2 ++
tests/test_vit.py | 1 +
tests/test_vitautoenc.py | 1 +
tests/test_vnet.py | 1 +
tests/test_vote_ensemble.py | 1 +
tests/test_vote_ensembled.py | 1 +
tests/test_voxelmorph.py | 1 +
tests/test_warp.py | 1 +
tests/test_watershed.py | 1 +
tests/test_watershedd.py | 1 +
tests/test_weight_init.py | 1 +
tests/test_weighted_random_sampler_dist.py | 1 +
tests/test_with_allow_missing_keys.py | 1 +
tests/test_write_metrics_reports.py | 1 +
tests/test_wsi_sliding_window_splitter.py | 2 ++
tests/test_wsireader.py | 4 ++++
tests/test_zarr_avg_merger.py | 1 +
tests/test_zipdataset.py | 2 ++
tests/test_zoom.py | 1 +
tests/test_zoom_affine.py | 1 +
tests/test_zoomd.py | 1 +
tests/utils.py | 2 ++
734 files changed, 1046 insertions(+), 4 deletions(-)
diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py
index bb10eb6b11..67ea3059cc 100644
--- a/monai/apps/datasets.py
+++ b/monai/apps/datasets.py
@@ -737,6 +737,7 @@ def get_dataset(self, folds: Sequence[int] | int, **dataset_params: Any) -> obje
dataset_params_.update(dataset_params)
class _NsplitsDataset(self.dataset_cls): # type: ignore
+
def _split_datalist(self, datalist: list[dict]) -> list[dict]:
data = partition_dataset(data=datalist, num_partitions=nfolds, shuffle=True, seed=seed)
return select_cross_validation_folds(partitions=data, folds=folds)
diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py
index 4c92b42c08..6d0825f54a 100644
--- a/monai/apps/deepedit/transforms.py
+++ b/monai/apps/deepedit/transforms.py
@@ -34,6 +34,7 @@
class DiscardAddGuidanced(MapTransform):
+
def __init__(
self,
keys: KeysCollection,
@@ -84,6 +85,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda
class NormalizeLabelsInDatasetd(MapTransform):
+
def __init__(
self, keys: KeysCollection, label_names: dict[str, int] | None = None, allow_missing_keys: bool = False
):
@@ -121,6 +123,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda
class SingleLabelSelectiond(MapTransform):
+
def __init__(
self, keys: KeysCollection, label_names: Sequence[str] | None = None, allow_missing_keys: bool = False
):
diff --git a/monai/apps/detection/metrics/coco.py b/monai/apps/detection/metrics/coco.py
index 033b763be5..a856f14fb8 100644
--- a/monai/apps/detection/metrics/coco.py
+++ b/monai/apps/detection/metrics/coco.py
@@ -72,6 +72,7 @@
class COCOMetric:
+
def __init__(
self,
classes: Sequence[str],
diff --git a/monai/apps/detection/utils/ATSS_matcher.py b/monai/apps/detection/utils/ATSS_matcher.py
index cc9e238862..5b8f950ab3 100644
--- a/monai/apps/detection/utils/ATSS_matcher.py
+++ b/monai/apps/detection/utils/ATSS_matcher.py
@@ -164,6 +164,7 @@ def compute_matches(
class ATSSMatcher(Matcher):
+
def __init__(
self,
num_candidates: int = 4,
diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py
index ef9c535725..99e94f89c0 100644
--- a/monai/apps/pathology/transforms/post/array.py
+++ b/monai/apps/pathology/transforms/post/array.py
@@ -162,7 +162,7 @@ def __call__(self, prob_map: NdarrayOrTensor) -> NdarrayOrTensor:
pred = label(pred)[0]
if self.remove_small_objects is not None:
pred = self.remove_small_objects(pred)
- pred[pred > 0] = 1 # type: ignore
+ pred[pred > 0] = 1
return convert_to_dst_type(pred, prob_map, dtype=self.dtype)[0]
@@ -338,7 +338,7 @@ def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> N
instance_border = instance_border >= self.threshold # uncertain area
marker = mask - convert_to_dst_type(instance_border, mask)[0] # certain foreground
- marker[marker < 0] = 0 # type: ignore
+ marker[marker < 0] = 0
marker = self.postprocess_fn(marker)
marker = convert_to_numpy(marker)
@@ -634,7 +634,7 @@ def __call__( # type: ignore
seg_map_crop = convert_to_dst_type(seg_map_crop == instance_id, type_map_crop, dtype=bool)[0]
- inst_type = type_map_crop[seg_map_crop] # type: ignore
+ inst_type = type_map_crop[seg_map_crop]
type_list, type_pixels = unique(inst_type, return_counts=True)
type_list = list(zip(type_list, type_pixels))
type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index eba850225d..531893d768 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -1275,6 +1275,7 @@ def __len__(self) -> int:
return min(len(dataset) for dataset in self.data)
def _transform(self, index: int):
+
def to_list(x):
return list(x) if isinstance(x, (tuple, list)) else [x]
diff --git a/monai/fl/client/client_algo.py b/monai/fl/client/client_algo.py
index 25a88a9e66..3dc9f5785d 100644
--- a/monai/fl/client/client_algo.py
+++ b/monai/fl/client/client_algo.py
@@ -57,6 +57,7 @@ def abort(self, extra: dict | None = None) -> None:
class ClientAlgoStats(BaseClient):
+
def get_data_stats(self, extra: dict | None = None) -> ExchangeObject:
"""
Get summary statistics about the local data.
diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py
index 0382b8cb64..021154d705 100644
--- a/monai/handlers/ignite_metric.py
+++ b/monai/handlers/ignite_metric.py
@@ -157,6 +157,7 @@ def attach(self, engine: Engine, name: str) -> None: # type: ignore[override]
@deprecated(since="1.2", removed="1.4", msg_suffix="Use IgniteMetricHandler instead of IgniteMetric.")
class IgniteMetric(IgniteMetricHandler):
+
def __init__(
self,
metric_fn: CumulativeIterationMetric | None = None,
diff --git a/monai/metrics/f_beta_score.py b/monai/metrics/f_beta_score.py
index 61e4525662..bb9371c8bf 100644
--- a/monai/metrics/f_beta_score.py
+++ b/monai/metrics/f_beta_score.py
@@ -22,6 +22,7 @@
class FBetaScore(CumulativeIterationMetric):
+
def __init__(
self,
beta: float = 1.0,
diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py
index 12afab3464..801b49de8b 100644
--- a/monai/networks/blocks/dynunet_block.py
+++ b/monai/networks/blocks/dynunet_block.py
@@ -245,6 +245,7 @@ def forward(self, inp, skip):
class UnetOutBlock(nn.Module):
+
def __init__(
self, spatial_dims: int, in_channels: int, out_channels: int, dropout: tuple | str | float | None = None
):
diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py
index 11808eabf7..6e0efc8588 100644
--- a/monai/networks/blocks/localnet_block.py
+++ b/monai/networks/blocks/localnet_block.py
@@ -72,6 +72,7 @@ def get_deconv_block(spatial_dims: int, in_channels: int, out_channels: int) ->
class ResidualBlock(nn.Module):
+
def __init__(
self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Sequence[int] | int
) -> None:
@@ -95,6 +96,7 @@ def forward(self, x) -> torch.Tensor:
class LocalNetResidualBlock(nn.Module):
+
def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:
super().__init__()
if in_channels != out_channels:
diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py
index 138149cac6..e03553307e 100644
--- a/monai/networks/blocks/pos_embed_utils.py
+++ b/monai/networks/blocks/pos_embed_utils.py
@@ -23,6 +23,7 @@
# From PyTorch internals
def _ntuple(n):
+
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
diff --git a/monai/networks/layers/gmm.py b/monai/networks/layers/gmm.py
index 94d619bb7a..6ebe66832f 100644
--- a/monai/networks/layers/gmm.py
+++ b/monai/networks/layers/gmm.py
@@ -78,6 +78,7 @@ def apply(self, features):
class _ApplyFunc(torch.autograd.Function):
+
@staticmethod
def forward(ctx, params, features, compiled_extension):
return compiled_extension.apply(params, features)
diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py
index a1122ceaa2..4ac621967f 100644
--- a/monai/networks/layers/simplelayers.py
+++ b/monai/networks/layers/simplelayers.py
@@ -552,6 +552,7 @@ def forward(self, in_tensor: torch.Tensor, number_of_passes=1) -> torch.Tensor:
class GaussianFilter(nn.Module):
+
def __init__(
self,
spatial_dims: int,
@@ -607,6 +608,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class LLTMFunction(Function):
+
@staticmethod
def forward(ctx, input, weights, bias, old_h, old_cell):
outputs = _C.lltm_forward(input, weights, bias, old_h, old_cell)
diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py
index 53f35e63f2..2d39dfdbc1 100644
--- a/monai/networks/layers/spatial_transforms.py
+++ b/monai/networks/layers/spatial_transforms.py
@@ -33,6 +33,7 @@
class _GridPull(torch.autograd.Function):
+
@staticmethod
def forward(ctx, input, grid, interpolation, bound, extrapolate):
opt = (bound, interpolation, extrapolate)
@@ -132,6 +133,7 @@ def grid_pull(
class _GridPush(torch.autograd.Function):
+
@staticmethod
def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
opt = (bound, interpolation, extrapolate)
@@ -236,6 +238,7 @@ def grid_push(
class _GridCount(torch.autograd.Function):
+
@staticmethod
def forward(ctx, grid, shape, interpolation, bound, extrapolate):
opt = (bound, interpolation, extrapolate)
@@ -335,6 +338,7 @@ def grid_count(grid: torch.Tensor, shape=None, interpolation="linear", bound="ze
class _GridGrad(torch.autograd.Function):
+
@staticmethod
def forward(ctx, input, grid, interpolation, bound, extrapolate):
opt = (bound, interpolation, extrapolate)
@@ -433,6 +437,7 @@ def grid_grad(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b
class AffineTransform(nn.Module):
+
def __init__(
self,
spatial_size: Sequence[int] | int | None = None,
diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py
index b0ad1eabbd..5e280d7f24 100644
--- a/monai/networks/nets/ahnet.py
+++ b/monai/networks/nets/ahnet.py
@@ -87,6 +87,7 @@ def forward(self, x):
class Projection(nn.Sequential):
+
def __init__(self, spatial_dims: int, num_input_features: int, num_output_features: int):
super().__init__()
@@ -100,6 +101,7 @@ def __init__(self, spatial_dims: int, num_input_features: int, num_output_featur
class DenseBlock(nn.Sequential):
+
def __init__(
self,
spatial_dims: int,
@@ -118,6 +120,7 @@ def __init__(
class UpTransition(nn.Sequential):
+
def __init__(
self, spatial_dims: int, num_input_features: int, num_output_features: int, upsample_mode: str = "transpose"
):
@@ -143,6 +146,7 @@ def __init__(
class Final(nn.Sequential):
+
def __init__(
self, spatial_dims: int, num_input_features: int, num_output_features: int, upsample_mode: str = "transpose"
):
@@ -178,6 +182,7 @@ def __init__(
class Pseudo3DLayer(nn.Module):
+
def __init__(self, spatial_dims: int, num_input_features: int, growth_rate: int, bn_size: int, dropout_prob: float):
super().__init__()
# 1x1x1
@@ -244,6 +249,7 @@ def forward(self, x):
class PSP(nn.Module):
+
def __init__(self, spatial_dims: int, psp_block_num: int, in_ch: int, upsample_mode: str = "transpose"):
super().__init__()
self.up_modules = nn.ModuleList()
diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py
index 362d63d636..5689cf1071 100644
--- a/monai/networks/nets/attentionunet.py
+++ b/monai/networks/nets/attentionunet.py
@@ -23,6 +23,7 @@
class ConvBlock(nn.Module):
+
def __init__(
self,
spatial_dims: int,
@@ -67,6 +68,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class UpConv(nn.Module):
+
def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size=3, strides=2, dropout=0.0):
super().__init__()
self.up = Convolution(
@@ -88,6 +90,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class AttentionBlock(nn.Module):
+
def __init__(self, spatial_dims: int, f_int: int, f_g: int, f_l: int, dropout=0.0):
super().__init__()
self.W_g = nn.Sequential(
@@ -145,6 +148,7 @@ def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
class AttentionLayer(nn.Module):
+
def __init__(
self,
spatial_dims: int,
diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py
index 7fc57edc42..b9970d4113 100644
--- a/monai/networks/nets/basic_unet.py
+++ b/monai/networks/nets/basic_unet.py
@@ -176,6 +176,7 @@ def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):
class BasicUNet(nn.Module):
+
def __init__(
self,
spatial_dims: int = 3,
diff --git a/monai/networks/nets/basic_unetplusplus.py b/monai/networks/nets/basic_unetplusplus.py
index 28d4b4668a..f7ae768513 100644
--- a/monai/networks/nets/basic_unetplusplus.py
+++ b/monai/networks/nets/basic_unetplusplus.py
@@ -24,6 +24,7 @@
class BasicUNetPlusPlus(nn.Module):
+
def __init__(
self,
spatial_dims: int = 3,
diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py
index 2100272d91..5ccb429c91 100644
--- a/monai/networks/nets/densenet.py
+++ b/monai/networks/nets/densenet.py
@@ -42,6 +42,7 @@
class _DenseLayer(nn.Module):
+
def __init__(
self,
spatial_dims: int,
@@ -88,6 +89,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class _DenseBlock(nn.Sequential):
+
def __init__(
self,
spatial_dims: int,
@@ -119,6 +121,7 @@ def __init__(
class _Transition(nn.Sequential):
+
def __init__(
self,
spatial_dims: int,
diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py
index 6e3420d136..129e0925d3 100644
--- a/monai/networks/nets/dints.py
+++ b/monai/networks/nets/dints.py
@@ -73,6 +73,7 @@ def _dfs(node, paths):
class _IdentityWithRAMCost(nn.Identity):
+
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ram_cost = 0
@@ -105,6 +106,7 @@ def __init__(
class _P3DActiConvNormBlockWithRAMCost(P3DActiConvNormBlock):
+
def __init__(
self,
in_channel: int,
@@ -122,6 +124,7 @@ def __init__(
class _FactorizedIncreaseBlockWithRAMCost(FactorizedIncreaseBlock):
+
def __init__(
self,
in_channel: int,
@@ -138,6 +141,7 @@ def __init__(
class _FactorizedReduceBlockWithRAMCost(FactorizedReduceBlock):
+
def __init__(
self,
in_channel: int,
diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py
index d89ab53ea2..4e6c327b23 100644
--- a/monai/networks/nets/efficientnet.py
+++ b/monai/networks/nets/efficientnet.py
@@ -73,6 +73,7 @@
class MBConvBlock(nn.Module):
+
def __init__(
self,
spatial_dims: int,
@@ -227,6 +228,7 @@ def set_swish(self, memory_efficient: bool = True) -> None:
class EfficientNet(nn.Module):
+
def __init__(
self,
blocks_args_str: list[str],
@@ -472,6 +474,7 @@ def _initialize_weights(self) -> None:
class EfficientNetBN(EfficientNet):
+
def __init__(
self,
model_name: str,
@@ -558,6 +561,7 @@ def __init__(
class EfficientNetBNFeatures(EfficientNet):
+
def __init__(
self,
model_name: str,
diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py
index e71f8d193d..4959f0713f 100644
--- a/monai/networks/nets/highresnet.py
+++ b/monai/networks/nets/highresnet.py
@@ -36,6 +36,7 @@
class HighResBlock(nn.Module):
+
def __init__(
self,
spatial_dims: int,
diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py
index 3ec1cea37e..5f340c9be6 100644
--- a/monai/networks/nets/hovernet.py
+++ b/monai/networks/nets/hovernet.py
@@ -49,6 +49,7 @@
class _DenseLayerDecoder(nn.Module):
+
def __init__(
self,
num_features: int,
@@ -103,6 +104,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class _DecoderBlock(nn.Sequential):
+
def __init__(
self,
layers: int,
@@ -159,6 +161,7 @@ def __init__(
class _DenseLayer(nn.Sequential):
+
def __init__(
self,
num_features: int,
@@ -219,6 +222,7 @@ def __init__(
class _Transition(nn.Sequential):
+
def __init__(
self, in_channels: int, act: str | tuple = ("relu", {"inplace": True}), norm: str | tuple = "batch"
) -> None:
@@ -235,6 +239,7 @@ def __init__(
class _ResidualBlock(nn.Module):
+
def __init__(
self,
layers: int,
@@ -312,6 +317,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class _DecoderBranch(nn.ModuleList):
+
def __init__(
self,
decode_config: Sequence[int] = (8, 4),
diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py
index 0a25b7feec..ad6b77bf3d 100644
--- a/monai/networks/nets/milmodel.py
+++ b/monai/networks/nets/milmodel.py
@@ -83,6 +83,7 @@ def __init__(
if mil_mode == "att_trans_pyramid":
# register hooks to capture outputs of intermediate layers
def forward_hook(layer_name):
+
def hook(module, input, output):
self.extra_outputs[layer_name] = output
diff --git a/monai/networks/nets/regunet.py b/monai/networks/nets/regunet.py
index a7c5158240..4d6150ea1b 100644
--- a/monai/networks/nets/regunet.py
+++ b/monai/networks/nets/regunet.py
@@ -234,6 +234,7 @@ def forward(self, x):
class AffineHead(nn.Module):
+
def __init__(
self,
spatial_dims: int,
@@ -375,6 +376,7 @@ def build_output_block(self):
class AdditiveUpSampleBlock(nn.Module):
+
def __init__(
self,
spatial_dims: int,
diff --git a/monai/networks/nets/vnet.py b/monai/networks/nets/vnet.py
index d89eb8ae03..2815224e08 100644
--- a/monai/networks/nets/vnet.py
+++ b/monai/networks/nets/vnet.py
@@ -30,6 +30,7 @@ def get_acti_layer(act: tuple[str, dict] | str, nchan: int = 0):
class LUConv(nn.Module):
+
def __init__(self, spatial_dims: int, nchan: int, act: tuple[str, dict] | str, bias: bool = False):
super().__init__()
@@ -58,6 +59,7 @@ def _make_nconv(spatial_dims: int, nchan: int, depth: int, act: tuple[str, dict]
class InputTransition(nn.Module):
+
def __init__(
self, spatial_dims: int, in_channels: int, out_channels: int, act: tuple[str, dict] | str, bias: bool = False
):
@@ -91,6 +93,7 @@ def forward(self, x):
class DownTransition(nn.Module):
+
def __init__(
self,
spatial_dims: int,
@@ -127,6 +130,7 @@ def forward(self, x):
class UpTransition(nn.Module):
+
def __init__(
self,
spatial_dims: int,
@@ -165,6 +169,7 @@ def forward(self, x, skipx):
class OutputTransition(nn.Module):
+
def __init__(
self, spatial_dims: int, in_channels: int, out_channels: int, act: tuple[str, dict] | str, bias: bool = False
):
diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py
index 75e108ae71..045135628d 100644
--- a/monai/optimizers/lr_finder.py
+++ b/monai/optimizers/lr_finder.py
@@ -43,6 +43,7 @@
class DataLoaderIter:
+
def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:
if not isinstance(data_loader, DataLoader):
raise ValueError(
@@ -71,6 +72,7 @@ def __next__(self):
class TrainDataLoaderIter(DataLoaderIter):
+
def __init__(
self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable, auto_reset: bool = True
) -> None:
diff --git a/monai/optimizers/utils.py b/monai/optimizers/utils.py
index 7e566abb46..75a125f076 100644
--- a/monai/optimizers/utils.py
+++ b/monai/optimizers/utils.py
@@ -70,12 +70,14 @@ def generate_param_groups(
lr_values = ensure_tuple_rep(lr_values, len(layer_matches))
def _get_select(f):
+
def _select():
return f(network).parameters()
return _select
def _get_filter(f):
+
def _filter():
# should eventually generate a list of network parameters
return (x[1] for x in filter(f, network.named_parameters()))
diff --git a/monai/transforms/adaptors.py b/monai/transforms/adaptors.py
index 5729740690..f5f1a4fc18 100644
--- a/monai/transforms/adaptors.py
+++ b/monai/transforms/adaptors.py
@@ -132,6 +132,7 @@ def __call__(self, img, seg):
@_monai_export("monai.transforms")
def adaptor(function, outputs, inputs=None):
+
def must_be_types_or_none(variable_name, variable, types):
if variable is not None:
if not isinstance(variable, types):
@@ -216,6 +217,7 @@ def _inner(ditems):
@_monai_export("monai.transforms")
def apply_alias(fn, name_map):
+
def _inner(data):
# map names
pre_call = dict(data)
@@ -236,6 +238,7 @@ def _inner(data):
@_monai_export("monai.transforms")
def to_kwargs(fn):
+
def _inner(data):
return fn(**data)
@@ -243,6 +246,7 @@ def _inner(data):
class FunctionSignature:
+
def __init__(self, function: Callable) -> None:
import inspect
diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py
index 73149f1be5..1a7d16fb8c 100644
--- a/monai/transforms/inverse_batch_transform.py
+++ b/monai/transforms/inverse_batch_transform.py
@@ -30,6 +30,7 @@
class _BatchInverseDataset(Dataset):
+
def __init__(self, data: Sequence[Any], transform: InvertibleTransform, pad_collation_used: bool) -> None:
self.data = data
self.invertible_transform = transform
diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py
index 3d5d30be92..da9b23ce57 100644
--- a/monai/transforms/post/array.py
+++ b/monai/transforms/post/array.py
@@ -631,6 +631,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
class Ensemble:
+
@staticmethod
def get_stacked_torch(img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> torch.Tensor:
"""Get either a sequence or single instance of np.ndarray/torch.Tensor. Return single torch.Tensor."""
diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index 8ad86b72dd..094afdd3c4 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -3441,7 +3441,7 @@ def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tupl
idx = self.R.permutation(image_np.shape[0])
idx = idx[: self.num_patches]
idx_np = convert_data_type(idx, np.ndarray)[0]
- image_np = image_np[idx] # type: ignore
+ image_np = image_np[idx]
locations = locations[idx_np]
return image_np, locations
elif self.sort_fn not in (None, GridPatchSort.MIN, GridPatchSort.MAX):
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index 1cd9ff6323..7e3a7b0454 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -670,6 +670,7 @@ def __init__(self, keys: KeysCollection, sep: str = ".", use_re: Sequence[bool]
self.use_re = ensure_tuple_rep(use_re, len(self.keys))
def __call__(self, data):
+
def _delete_item(keys, d, use_re: bool = False):
key = keys[0]
if len(keys) > 1:
diff --git a/monai/utils/misc.py b/monai/utils/misc.py
index 2a5c5da136..caa7c067df 100644
--- a/monai/utils/misc.py
+++ b/monai/utils/misc.py
@@ -742,6 +742,7 @@ def check_key_duplicates(ordered_pairs: Sequence[tuple[Any, Any]]) -> dict[Any,
class CheckKeyDuplicatesYamlLoader(SafeLoader):
+
def construct_mapping(self, node, deep=False):
mapping = set()
for key_node, _ in node.value:
diff --git a/monai/utils/module.py b/monai/utils/module.py
index f46ba7c1b3..db62e1e72b 100644
--- a/monai/utils/module.py
+++ b/monai/utils/module.py
@@ -418,6 +418,7 @@ def optional_import(
msg += f" ({exception_str})"
class _LazyRaise:
+
def __init__(self, *_args, **_kwargs):
_default_msg = (
f"{msg}."
@@ -453,6 +454,7 @@ def __iter__(self):
return _LazyRaise(), False
class _LazyCls(_LazyRaise):
+
def __init__(self, *_args, **kwargs):
super().__init__()
if not as_type.startswith("decorator"):
diff --git a/monai/utils/profiling.py b/monai/utils/profiling.py
index da5c0ac05c..5c880bbe1f 100644
--- a/monai/utils/profiling.py
+++ b/monai/utils/profiling.py
@@ -336,6 +336,7 @@ def profile_iter(self, name, iterable):
"""Wrapper around anything iterable to profile how long it takes to generate items."""
class _Iterable:
+
def __iter__(_self): # noqa: B902, N805 pylint: disable=E0213
do_iter = True
orig_iter = iter(iterable)
diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py
index 81d0bb32c4..6d1e8dfd03 100644
--- a/monai/visualize/class_activation_maps.py
+++ b/monai/visualize/class_activation_maps.py
@@ -96,12 +96,14 @@ def __init__(
warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.")
def backward_hook(self, name):
+
def _hook(_module, _grad_input, grad_output):
self.gradients[name] = grad_output[0]
return _hook
def forward_hook(self, name):
+
def _hook(_module, _input, output):
self.activations[name] = output
diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py
index e2e938d86b..c54c9cd4ca 100644
--- a/monai/visualize/gradient_based.py
+++ b/monai/visualize/gradient_based.py
@@ -26,6 +26,7 @@
class _AutoGradReLU(torch.autograd.Function):
+
@staticmethod
def forward(ctx, x):
pos_mask = (x > 0).type_as(x)
diff --git a/tests/croppers.py b/tests/croppers.py
index 8c9b43bf0a..cfececfa9f 100644
--- a/tests/croppers.py
+++ b/tests/croppers.py
@@ -24,6 +24,7 @@
class CropTest(unittest.TestCase):
+
@staticmethod
def get_arr(shape):
return np.random.randint(100, size=shape).astype(float)
diff --git a/tests/hvd_evenly_divisible_all_gather.py b/tests/hvd_evenly_divisible_all_gather.py
index c7baac2bc9..78c6ca06bc 100644
--- a/tests/hvd_evenly_divisible_all_gather.py
+++ b/tests/hvd_evenly_divisible_all_gather.py
@@ -21,6 +21,7 @@
class HvdEvenlyDivisibleAllGather:
+
def test_data(self):
# initialize Horovod
hvd.init()
diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py
index ba35f2b80c..01dc044870 100644
--- a/tests/ngc_bundle_download.py
+++ b/tests/ngc_bundle_download.py
@@ -70,6 +70,7 @@
@skip_if_windows
class TestNgcBundleDownload(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_NGC_1, TEST_CASE_NGC_2])
@skip_if_quick
def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download_name, file_path, hash_val):
@@ -101,6 +102,7 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download
@unittest.skip("deprecating mmar tests")
class TestAllDownloadingMMAR(unittest.TestCase):
+
def setUp(self):
print_debug_info()
self.test_dir = "./"
diff --git a/tests/padders.py b/tests/padders.py
index ae1153bdfd..a7dce263bb 100644
--- a/tests/padders.py
+++ b/tests/padders.py
@@ -51,6 +51,7 @@
class PadTest(unittest.TestCase):
+
@staticmethod
def get_arr(shape):
return np.random.randint(100, size=shape).astype(float)
diff --git a/tests/profile_subclass/min_classes.py b/tests/profile_subclass/min_classes.py
index 7104ffcd59..3e7c52476f 100644
--- a/tests/profile_subclass/min_classes.py
+++ b/tests/profile_subclass/min_classes.py
@@ -25,5 +25,6 @@ class SubTensor(torch.Tensor):
class SubWithTorchFunc(torch.Tensor):
+
def __torch_function__(self, func, types, args=(), kwargs=None):
return super().__torch_function__(func, types, args, {} if kwargs is None else kwargs)
diff --git a/tests/test_acn_block.py b/tests/test_acn_block.py
index 2f3783cbb8..1cbf3ea168 100644
--- a/tests/test_acn_block.py
+++ b/tests/test_acn_block.py
@@ -29,6 +29,7 @@
class TestACNBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_acn_block(self, input_param, input_shape, expected_shape):
net = ActiConvNormBlock(**input_param)
diff --git a/tests/test_activations.py b/tests/test_activations.py
index 0e83c73304..ad18e2bbec 100644
--- a/tests/test_activations.py
+++ b/tests/test_activations.py
@@ -94,6 +94,7 @@
class TestActivations(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_value_shape(self, input_param, img, out, expected_shape):
result = Activations(**input_param)(img)
diff --git a/tests/test_activationsd.py b/tests/test_activationsd.py
index 22a275997c..74968c0bb4 100644
--- a/tests/test_activationsd.py
+++ b/tests/test_activationsd.py
@@ -50,6 +50,7 @@
class TestActivationsd(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_value_shape(self, input_param, test_input, output, expected_shape):
result = Activationsd(**input_param)(test_input)
diff --git a/tests/test_adaptors.py b/tests/test_adaptors.py
index 257c4346ad..2495fdc72e 100644
--- a/tests/test_adaptors.py
+++ b/tests/test_adaptors.py
@@ -18,13 +18,16 @@
class TestAdaptors(unittest.TestCase):
+
def test_function_signature(self):
+
def foo(image, label=None, *a, **kw):
pass
_ = FunctionSignature(foo)
def test_single_in_single_out(self):
+
def foo(image):
return image * 2
@@ -55,6 +58,7 @@ def foo(image):
self.assertEqual(dres["img"], 4)
def test_multi_in_single_out(self):
+
def foo(image, label):
return image * label
@@ -86,6 +90,7 @@ def foo(image, label):
self.assertEqual(dres["lbl"], 3)
def test_default_arg_single_out(self):
+
def foo(a, b=2):
return a * b
@@ -98,6 +103,7 @@ def foo(a, b=2):
dres = adaptor(foo, "c")(d)
def test_multi_out(self):
+
def foo(a, b):
return a * b, a / b
@@ -107,6 +113,7 @@ def foo(a, b):
self.assertEqual(dres["d"], 3 / 4)
def test_dict_out(self):
+
def foo(a):
return {"a": a * 2}
@@ -120,7 +127,9 @@ def foo(a):
class TestApplyAlias(unittest.TestCase):
+
def test_apply_alias(self):
+
def foo(d):
d["x"] *= 2
return d
@@ -131,7 +140,9 @@ def foo(d):
class TestToKwargs(unittest.TestCase):
+
def test_to_kwargs(self):
+
def foo(**kwargs):
results = {k: v * 2 for k, v in kwargs.items()}
return results
diff --git a/tests/test_add_coordinate_channels.py b/tests/test_add_coordinate_channels.py
index cd33f98fd5..199fe071e3 100644
--- a/tests/test_add_coordinate_channels.py
+++ b/tests/test_add_coordinate_channels.py
@@ -29,6 +29,7 @@
class TestAddCoordinateChannels(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, input_param, input, expected_shape):
result = AddCoordinateChannels(**input_param)(input)
diff --git a/tests/test_add_coordinate_channelsd.py b/tests/test_add_coordinate_channelsd.py
index f5784928fd..c00240c2d5 100644
--- a/tests/test_add_coordinate_channelsd.py
+++ b/tests/test_add_coordinate_channelsd.py
@@ -42,6 +42,7 @@
class TestAddCoordinateChannels(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, input_param, input, expected_shape):
result = AddCoordinateChannelsd(**input_param)(input)["img"]
diff --git a/tests/test_add_extreme_points_channel.py b/tests/test_add_extreme_points_channel.py
index 140caa34ba..c453322d6b 100644
--- a/tests/test_add_extreme_points_channel.py
+++ b/tests/test_add_extreme_points_channel.py
@@ -69,6 +69,7 @@
class TestAddExtremePointsChannel(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_correct_results(self, input_data, expected):
add_extreme_points_channel = AddExtremePointsChannel()
diff --git a/tests/test_add_extreme_points_channeld.py b/tests/test_add_extreme_points_channeld.py
index 5640e696fc..026f71200a 100644
--- a/tests/test_add_extreme_points_channeld.py
+++ b/tests/test_add_extreme_points_channeld.py
@@ -64,6 +64,7 @@
class TestAddExtremePointsChanneld(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_correct_results(self, input_data, expected):
add_extreme_points_channel = AddExtremePointsChanneld(
diff --git a/tests/test_adjust_contrast.py b/tests/test_adjust_contrast.py
index 9fa0247115..2236056558 100644
--- a/tests/test_adjust_contrast.py
+++ b/tests/test_adjust_contrast.py
@@ -30,6 +30,7 @@
class TestAdjustContrast(NumpyImageTestCase2D):
+
@parameterized.expand(TESTS)
def test_correct_results(self, gamma, invert_image, retain_stats):
adjuster = AdjustContrast(gamma=gamma, invert_image=invert_image, retain_stats=retain_stats)
diff --git a/tests/test_adjust_contrastd.py b/tests/test_adjust_contrastd.py
index 4a671ef7be..38eb001226 100644
--- a/tests/test_adjust_contrastd.py
+++ b/tests/test_adjust_contrastd.py
@@ -30,6 +30,7 @@
class TestAdjustContrastd(NumpyImageTestCase2D):
+
@parameterized.expand(TESTS)
def test_correct_results(self, gamma, invert_image, retain_stats):
adjuster = AdjustContrastd("img", gamma=gamma, invert_image=invert_image, retain_stats=retain_stats)
diff --git a/tests/test_adn.py b/tests/test_adn.py
index 27e23a08d3..327bf7b20c 100644
--- a/tests/test_adn.py
+++ b/tests/test_adn.py
@@ -59,6 +59,7 @@
class TestADN2D(TorchImageTestCase2D):
+
@parameterized.expand(TEST_CASES_2D)
def test_adn_2d(self, args):
adn = ADN(**args)
@@ -73,6 +74,7 @@ def test_no_input(self):
class TestADN3D(TorchImageTestCase3D):
+
@parameterized.expand(TEST_CASES_3D)
def test_adn_3d(self, args):
adn = ADN(**args)
diff --git a/tests/test_adversarial_loss.py b/tests/test_adversarial_loss.py
index 77880725ec..f7b9ae7eb0 100644
--- a/tests/test_adversarial_loss.py
+++ b/tests/test_adversarial_loss.py
@@ -39,6 +39,7 @@
class TestPatchAdversarialLoss(unittest.TestCase):
+
def get_input(self, shape, is_positive):
"""
Get tensor for the tests. The tensor is around (-1) or (+1), depending on
diff --git a/tests/test_affine.py b/tests/test_affine.py
index 9c2f4197a6..a08a22ae6f 100644
--- a/tests/test_affine.py
+++ b/tests/test_affine.py
@@ -167,6 +167,7 @@
class TestAffine(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_affine(self, input_param, input_data, expected_val):
input_copy = deepcopy(input_data["img"])
@@ -199,6 +200,7 @@ def test_affine(self, input_param, input_data, expected_val):
@unittest.skipUnless(optional_import("scipy")[1], "Requires scipy library.")
class TestAffineConsistency(unittest.TestCase):
+
@parameterized.expand([[7], [8], [9]])
def test_affine_resize(self, s):
"""s"""
diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py
index f3febbe0f3..2d89725bb7 100644
--- a/tests/test_affine_grid.py
+++ b/tests/test_affine_grid.py
@@ -135,6 +135,7 @@
class TestAffineGrid(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_affine_grid(self, input_param, input_data, expected_val):
g = AffineGrid(**input_param)
diff --git a/tests/test_affine_transform.py b/tests/test_affine_transform.py
index 39dc609167..6ea036bce8 100644
--- a/tests/test_affine_transform.py
+++ b/tests/test_affine_transform.py
@@ -83,6 +83,7 @@
class TestNormTransform(unittest.TestCase):
+
@parameterized.expand(TEST_NORM_CASES)
def test_norm_xform(self, input_shape, align_corners, expected, zero_centered=False):
norm = normalize_transform(
@@ -107,6 +108,7 @@ def test_norm_xform(self, input_shape, align_corners, expected, zero_centered=Fa
class TestToNormAffine(unittest.TestCase):
+
@parameterized.expand(TEST_TO_NORM_AFFINE_CASES)
def test_to_norm_affine(self, affine, src_size, dst_size, align_corners, expected, zero_centered=False):
affine = torch.as_tensor(affine, device=torch.device("cpu:0"), dtype=torch.float32)
@@ -130,6 +132,7 @@ def test_to_norm_affine_ill(self, affine, src_size, dst_size, align_corners):
class TestAffineTransform(unittest.TestCase):
+
def test_affine_shift(self):
affine = torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]])
image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
diff --git a/tests/test_affined.py b/tests/test_affined.py
index a35b35758a..94903ff8c7 100644
--- a/tests/test_affined.py
+++ b/tests/test_affined.py
@@ -168,6 +168,7 @@
class TestAffined(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_affine(self, input_param, input_data, expected_val):
input_copy = deepcopy(input_data)
diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py
index 5707cf0452..99a177f395 100644
--- a/tests/test_ahnet.py
+++ b/tests/test_ahnet.py
@@ -126,6 +126,7 @@
class TestFCN(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_FCN_1, TEST_CASE_FCN_2, TEST_CASE_FCN_3])
@skip_if_quick
def test_fcn_shape(self, input_param, input_shape, expected_shape):
@@ -136,6 +137,7 @@ def test_fcn_shape(self, input_param, input_shape, expected_shape):
class TestFCNWithPretrain(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_FCN_WITH_PRETRAIN_1, TEST_CASE_FCN_WITH_PRETRAIN_2])
@skip_if_quick
def test_fcn_shape(self, input_param, input_shape, expected_shape):
@@ -146,6 +148,7 @@ def test_fcn_shape(self, input_param, input_shape, expected_shape):
class TestMCFCN(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_MCFCN_1, TEST_CASE_MCFCN_2, TEST_CASE_MCFCN_3])
def test_mcfcn_shape(self, input_param, input_shape, expected_shape):
net = MCFCN(**input_param).to(device)
@@ -155,6 +158,7 @@ def test_mcfcn_shape(self, input_param, input_shape, expected_shape):
class TestMCFCNWithPretrain(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_MCFCN_WITH_PRETRAIN_1, TEST_CASE_MCFCN_WITH_PRETRAIN_2])
def test_mcfcn_shape(self, input_param, input_shape, expected_shape):
net = test_pretrained_networks(MCFCN, input_param, device)
@@ -164,6 +168,7 @@ def test_mcfcn_shape(self, input_param, input_shape, expected_shape):
class TestAHNET(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_AHNET_2D_1, TEST_CASE_AHNET_2D_2, TEST_CASE_AHNET_2D_3])
def test_ahnet_shape_2d(self, input_param, input_shape, expected_shape):
net = AHNet(**input_param).to(device)
@@ -192,6 +197,7 @@ def test_script(self):
class TestAHNETWithPretrain(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_AHNET_3D_WITH_PRETRAIN_1, TEST_CASE_AHNET_3D_WITH_PRETRAIN_2, TEST_CASE_AHNET_3D_WITH_PRETRAIN_3]
)
diff --git a/tests/test_anchor_box.py b/tests/test_anchor_box.py
index c29296e8ae..301ce78361 100644
--- a/tests/test_anchor_box.py
+++ b/tests/test_anchor_box.py
@@ -42,6 +42,7 @@
@SkipIfBeforePyTorchVersion((1, 11))
@unittest.skipUnless(has_torchvision, "Requires torchvision")
class TestAnchorGenerator(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_2D)
def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):
torch_anchor_utils, _ = optional_import("torchvision.models.detection.anchor_utils")
diff --git a/tests/test_apply.py b/tests/test_apply.py
index 4784d46413..ca37e945ba 100644
--- a/tests/test_apply.py
+++ b/tests/test_apply.py
@@ -39,6 +39,7 @@ def single_2d_transform_cases():
class TestApply(unittest.TestCase):
+
def _test_apply_impl(self, tensor, pending_transforms, expected_shape):
result = apply_pending(tensor, pending_transforms)
self.assertListEqual(result[1], pending_transforms)
diff --git a/tests/test_apply_filter.py b/tests/test_apply_filter.py
index 0de77bfb4d..e8db6da4b9 100644
--- a/tests/test_apply_filter.py
+++ b/tests/test_apply_filter.py
@@ -20,6 +20,7 @@
class ApplyFilterTestCase(unittest.TestCase):
+
def test_1d(self):
a = torch.tensor([[list(range(10))]], dtype=torch.float)
out = apply_filter(a, torch.tensor([-1, 0, 1]), stride=1)
diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py
index a3b78fc6e0..efc014a267 100644
--- a/tests/test_arraydataset.py
+++ b/tests/test_arraydataset.py
@@ -40,6 +40,7 @@
class TestCompose(Compose):
+
def __call__(self, input_, lazy):
img = self.transforms[0](input_)
metadata = img.meta
@@ -77,6 +78,7 @@ def __call__(self, input_, lazy):
class TestArrayDataset(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_shape(self, img_transform, label_transform, indices, expected_shape):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)).astype(float), np.eye(4))
diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py
index 8f88fb2928..51e1a5c0fd 100644
--- a/tests/test_as_channel_last.py
+++ b/tests/test_as_channel_last.py
@@ -27,6 +27,7 @@
class TestAsChannelLast(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, in_type, input_param, expected_shape):
test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4]))
diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py
index 16086b769c..aa51ab6056 100644
--- a/tests/test_as_channel_lastd.py
+++ b/tests/test_as_channel_lastd.py
@@ -27,6 +27,7 @@
class TestAsChannelLastd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, in_type, input_param, expected_shape):
test_data = {
diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py
index 2802c7d9ff..bf59752920 100644
--- a/tests/test_as_discrete.py
+++ b/tests/test_as_discrete.py
@@ -65,6 +65,7 @@
class TestAsDiscrete(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_value_shape(self, input_param, img, out, expected_shape):
result = AsDiscrete(**input_param)(img)
diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py
index ec394fc3af..ed1b3c5b3e 100644
--- a/tests/test_as_discreted.py
+++ b/tests/test_as_discreted.py
@@ -68,6 +68,7 @@
class TestAsDiscreted(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_value_shape(self, input_param, test_input, output, expected_shape):
result = AsDiscreted(**input_param)(test_input)
diff --git a/tests/test_atss_box_matcher.py b/tests/test_atss_box_matcher.py
index a614497bc9..6133d4839d 100644
--- a/tests/test_atss_box_matcher.py
+++ b/tests/test_atss_box_matcher.py
@@ -33,6 +33,7 @@
class TestATSS(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_atss(self, input_param, boxes, anchors, num_anchors_per_level, num_anchors_per_loc, expected_matches):
matcher = ATSSMatcher(**input_param, debug=True)
diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py
index d5c67cee38..83f6cabc5e 100644
--- a/tests/test_attentionunet.py
+++ b/tests/test_attentionunet.py
@@ -20,6 +20,7 @@
class TestAttentionUnet(unittest.TestCase):
+
def test_attention_block(self):
for dims in [2, 3]:
block = att.AttentionBlock(dims, f_int=2, f_g=6, f_l=6)
diff --git a/tests/test_auto3dseg.py b/tests/test_auto3dseg.py
index 5964ddd6e9..e2097679e2 100644
--- a/tests/test_auto3dseg.py
+++ b/tests/test_auto3dseg.py
@@ -165,6 +165,7 @@ def __call__(self, data):
class TestDataAnalyzer(unittest.TestCase):
+
def setUp(self):
self.test_dir = tempfile.TemporaryDirectory()
work_dir = self.test_dir.name
diff --git a/tests/test_auto3dseg_bundlegen.py b/tests/test_auto3dseg_bundlegen.py
index 1d2d6611bb..e7bf6820bc 100644
--- a/tests/test_auto3dseg_bundlegen.py
+++ b/tests/test_auto3dseg_bundlegen.py
@@ -107,6 +107,7 @@ def run_auto3dseg_before_bundlegen(test_path, work_dir):
@SkipIfBeforePyTorchVersion((1, 11, 1))
@skip_if_quick
class TestBundleGen(unittest.TestCase):
+
def setUp(self) -> None:
set_determinism(0)
self.test_dir = tempfile.TemporaryDirectory()
diff --git a/tests/test_auto3dseg_ensemble.py b/tests/test_auto3dseg_ensemble.py
index 367f66581c..7ac553cc0c 100644
--- a/tests/test_auto3dseg_ensemble.py
+++ b/tests/test_auto3dseg_ensemble.py
@@ -112,6 +112,7 @@ def create_sim_data(dataroot, sim_datalist, sim_dim, **kwargs):
@SkipIfBeforePyTorchVersion((1, 11, 1))
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
class TestEnsembleBuilder(unittest.TestCase):
+
def setUp(self) -> None:
set_determinism(0)
self.test_dir = tempfile.TemporaryDirectory()
diff --git a/tests/test_auto3dseg_hpo.py b/tests/test_auto3dseg_hpo.py
index 0441116dc9..34d00336ec 100644
--- a/tests/test_auto3dseg_hpo.py
+++ b/tests/test_auto3dseg_hpo.py
@@ -79,6 +79,7 @@ def skip_if_no_optuna(obj):
@SkipIfBeforePyTorchVersion((1, 11, 1))
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
class TestHPO(unittest.TestCase):
+
def setUp(self) -> None:
self.test_dir = tempfile.TemporaryDirectory()
test_path = self.test_dir.name
@@ -154,6 +155,7 @@ def test_run_optuna(self) -> None:
algo = algo_dict[AlgoKeys.ALGO]
class OptunaGenLearningRate(OptunaGen):
+
def get_hyperparameters(self):
return {"learning_rate": self.trial.suggest_float("learning_rate", 0.00001, 0.1)}
diff --git a/tests/test_autoencoder.py b/tests/test_autoencoder.py
index 485049c2d1..6408f6a6d0 100644
--- a/tests/test_autoencoder.py
+++ b/tests/test_autoencoder.py
@@ -74,6 +74,7 @@
class TestAutoEncoder(unittest.TestCase):
+
@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shape):
net = AutoEncoder(**input_param).to(device)
diff --git a/tests/test_avg_merger.py b/tests/test_avg_merger.py
index adef2a759a..7995d63271 100644
--- a/tests/test_avg_merger.py
+++ b/tests/test_avg_merger.py
@@ -137,6 +137,7 @@
class AvgMergerTests(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_0_DEFAULT_DTYPE,
diff --git a/tests/test_basic_unet.py b/tests/test_basic_unet.py
index 23e19dd536..770750851f 100644
--- a/tests/test_basic_unet.py
+++ b/tests/test_basic_unet.py
@@ -83,6 +83,7 @@
class TestBasicUNET(unittest.TestCase):
+
@parameterized.expand(CASES_1D + CASES_2D + CASES_3D)
def test_shape(self, input_param, input_shape, expected_shape):
device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/tests/test_basic_unetplusplus.py b/tests/test_basic_unetplusplus.py
index 19ed5977fd..6438b5e0d4 100644
--- a/tests/test_basic_unetplusplus.py
+++ b/tests/test_basic_unetplusplus.py
@@ -83,6 +83,7 @@
class TestBasicUNETPlusPlus(unittest.TestCase):
+
@parameterized.expand(CASES_1D + CASES_2D + CASES_3D)
def test_shape(self, input_param, input_shape, expected_shape):
device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py
index f29d4f256b..2e8ab32dbd 100644
--- a/tests/test_bending_energy.py
+++ b/tests/test_bending_energy.py
@@ -50,6 +50,7 @@
class TestBendingEnergy(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = BendingEnergyLoss(**input_param).forward(**input_data)
diff --git a/tests/test_bilateral_approx_cpu.py b/tests/test_bilateral_approx_cpu.py
index da30d5d7de..e8a55e1f76 100644
--- a/tests/test_bilateral_approx_cpu.py
+++ b/tests/test_bilateral_approx_cpu.py
@@ -365,6 +365,7 @@
@skip_if_no_cpp_extension
class BilateralFilterTestCaseCpuApprox(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_cpu_approx(self, test_case_description, sigmas, input, expected):
# Params to determine the implementation to test
diff --git a/tests/test_bilateral_approx_cuda.py b/tests/test_bilateral_approx_cuda.py
index b9be7d9ccf..4ad15d9646 100644
--- a/tests/test_bilateral_approx_cuda.py
+++ b/tests/test_bilateral_approx_cuda.py
@@ -366,6 +366,7 @@
@skip_if_no_cuda
@skip_if_no_cpp_extension
class BilateralFilterTestCaseCudaApprox(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_cuda_approx(self, test_case_description, sigmas, input, expected):
# Skip this test
diff --git a/tests/test_bilateral_precise.py b/tests/test_bilateral_precise.py
index 1a68dc8b4e..e13ede5bfd 100644
--- a/tests/test_bilateral_precise.py
+++ b/tests/test_bilateral_precise.py
@@ -366,6 +366,7 @@
@skip_if_no_cpp_extension
@skip_if_quick
class BilateralFilterTestCaseCpuPrecise(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_cpu_precise(self, test_case_description, sigmas, input, expected):
# Params to determine the implementation to test
@@ -399,6 +400,7 @@ def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expec
@skip_if_no_cuda
@skip_if_no_cpp_extension
class BilateralFilterTestCaseCudaPrecise(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_cuda_precise(self, test_case_description, sigmas, input, expected):
# Skip this test
diff --git a/tests/test_blend_images.py b/tests/test_blend_images.py
index 9814a5a3f8..700ae1fe58 100644
--- a/tests/test_blend_images.py
+++ b/tests/test_blend_images.py
@@ -44,6 +44,7 @@ def get_alpha(img):
@skipUnless(has_matplotlib, "Matplotlib required")
class TestBlendImages(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_blend(self, image, label, alpha):
blended = blend_images(image, label, alpha)
diff --git a/tests/test_bounding_rect.py b/tests/test_bounding_rect.py
index b9c232e2d2..b879fa6093 100644
--- a/tests/test_bounding_rect.py
+++ b/tests/test_bounding_rect.py
@@ -28,6 +28,7 @@
class TestBoundingRect(unittest.TestCase):
+
def setUp(self):
monai.utils.set_determinism(1)
diff --git a/tests/test_bounding_rectd.py b/tests/test_bounding_rectd.py
index 248a0a8e47..96435036b1 100644
--- a/tests/test_bounding_rectd.py
+++ b/tests/test_bounding_rectd.py
@@ -28,6 +28,7 @@
class TestBoundingRectD(unittest.TestCase):
+
def setUp(self):
monai.utils.set_determinism(1)
diff --git a/tests/test_box_coder.py b/tests/test_box_coder.py
index 5835341139..75ff650d6c 100644
--- a/tests/test_box_coder.py
+++ b/tests/test_box_coder.py
@@ -21,6 +21,7 @@
class TestBoxTransform(unittest.TestCase):
+
def test_value(self):
box_coder = BoxCoder(weights=[1, 1, 1, 1, 1, 1])
test_dtype = [torch.float32, torch.float16]
diff --git a/tests/test_box_transform.py b/tests/test_box_transform.py
index e114f8869f..e99f95fa32 100644
--- a/tests/test_box_transform.py
+++ b/tests/test_box_transform.py
@@ -79,6 +79,7 @@
class TestBoxTransform(unittest.TestCase):
+
@parameterized.expand(TESTS_2D_mask)
def test_value_2d_mask(self, mask, expected_box_label):
box_label = convert_mask_to_box(mask)
diff --git a/tests/test_box_utils.py b/tests/test_box_utils.py
index c4fefb5a98..3c05efe0d0 100644
--- a/tests/test_box_utils.py
+++ b/tests/test_box_utils.py
@@ -140,6 +140,7 @@
class TestCreateBoxList(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, input_data, mode2, expected_box, expected_area):
expected_box = convert_data_type(expected_box, dtype=np.float32)[0]
diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py
index d9b3bedab2..8f376a06d5 100644
--- a/tests/test_bundle_ckpt_export.py
+++ b/tests/test_bundle_ckpt_export.py
@@ -32,6 +32,7 @@
@skip_if_windows
class TestCKPTExport(unittest.TestCase):
+
def setUp(self):
self.device = os.environ.get("CUDA_VISIBLE_DEVICES")
if not self.device:
diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py
index fa96c6f28d..89fbe5e8b2 100644
--- a/tests/test_bundle_download.py
+++ b/tests/test_bundle_download.py
@@ -93,6 +93,7 @@
class TestDownload(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
@skip_if_quick
def test_github_download_bundle(self, bundle_name, version):
@@ -192,6 +193,7 @@ def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, ve
@skip_if_no_cuda
class TestLoad(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_7])
@skip_if_quick
def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file):
@@ -336,6 +338,7 @@ def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device,
class TestDownloadLargefiles(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_10])
@skip_if_quick
def test_url_download_large_files(self, bundle_files, bundle_name, url, hash_val):
diff --git a/tests/test_bundle_get_data.py b/tests/test_bundle_get_data.py
index 88bfed758a..605b3945bb 100644
--- a/tests/test_bundle_get_data.py
+++ b/tests/test_bundle_get_data.py
@@ -45,6 +45,7 @@
@skip_if_windows
@SkipIfNoModule("requests")
class TestGetBundleData(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_3, TEST_CASE_4])
@skip_if_quick
def test_get_all_bundles_list(self, params):
diff --git a/tests/test_bundle_init_bundle.py b/tests/test_bundle_init_bundle.py
index 08f921da01..eb831093d5 100644
--- a/tests/test_bundle_init_bundle.py
+++ b/tests/test_bundle_init_bundle.py
@@ -23,6 +23,7 @@
@skip_if_windows
class TestBundleInit(unittest.TestCase):
+
def test_bundle(self):
with tempfile.TemporaryDirectory() as tempdir:
net = UNet(2, 1, 1, [4, 8], [2])
diff --git a/tests/test_bundle_onnx_export.py b/tests/test_bundle_onnx_export.py
index ffd5fa636d..ee22d7caef 100644
--- a/tests/test_bundle_onnx_export.py
+++ b/tests/test_bundle_onnx_export.py
@@ -29,6 +29,7 @@
@SkipIfNoModule("onnx")
@SkipIfBeforePyTorchVersion((1, 10))
class TestONNXExport(unittest.TestCase):
+
def setUp(self):
self.device = os.environ.get("CUDA_VISIBLE_DEVICES")
if not self.device:
diff --git a/tests/test_bundle_push_to_hf_hub.py b/tests/test_bundle_push_to_hf_hub.py
index 375c5d81e8..39368c6f40 100644
--- a/tests/test_bundle_push_to_hf_hub.py
+++ b/tests/test_bundle_push_to_hf_hub.py
@@ -28,6 +28,7 @@
class TestPushToHuggingFaceHub(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1])
@skip_if_quick
@skipUnless(has_huggingface_hub, "Requires `huggingface_hub` package.")
diff --git a/tests/test_bundle_trt_export.py b/tests/test_bundle_trt_export.py
index 72743f5fcb..47034852ef 100644
--- a/tests/test_bundle_trt_export.py
+++ b/tests/test_bundle_trt_export.py
@@ -48,6 +48,7 @@
@skip_if_no_cuda
@skip_if_quick
class TestTRTExport(unittest.TestCase):
+
def setUp(self):
self.device = os.environ.get("CUDA_VISIBLE_DEVICES")
if not self.device:
diff --git a/tests/test_bundle_utils.py b/tests/test_bundle_utils.py
index 181c08475c..47c534f3b6 100644
--- a/tests/test_bundle_utils.py
+++ b/tests/test_bundle_utils.py
@@ -51,6 +51,7 @@
@skip_if_windows
class TestLoadBundleConfig(unittest.TestCase):
+
def setUp(self):
self.bundle_dir = tempfile.TemporaryDirectory()
self.dir_name = os.path.join(self.bundle_dir.name, "TestBundle")
@@ -134,6 +135,7 @@ def test_load_config_ts(self):
class TestPPrintEdges(unittest.TestCase):
+
def test_str(self):
self.assertEqual(pprint_edges("", 0), "''")
self.assertEqual(pprint_edges({"a": 1, "b": 2}, 0), "{'a': 1, 'b': 2}")
diff --git a/tests/test_bundle_verify_metadata.py b/tests/test_bundle_verify_metadata.py
index 0701e905b9..f6c2192621 100644
--- a/tests/test_bundle_verify_metadata.py
+++ b/tests/test_bundle_verify_metadata.py
@@ -28,6 +28,7 @@
@skip_if_windows
class TestVerifyMetaData(unittest.TestCase):
+
def setUp(self):
self.config = testing_data_config("configs", "test_meta_file")
download_url_or_skip_test(
diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py
index 6f516fdd48..f55fdd597b 100644
--- a/tests/test_bundle_verify_net.py
+++ b/tests/test_bundle_verify_net.py
@@ -28,6 +28,7 @@
@skip_if_windows
class TestVerifyNetwork(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1])
def test_verify(self, meta_file, config_file):
with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py
index 4291eedf3f..f7da37acef 100644
--- a/tests/test_bundle_workflow.py
+++ b/tests/test_bundle_workflow.py
@@ -37,6 +37,7 @@
class TestBundleWorkflow(unittest.TestCase):
+
def setUp(self):
self.data_dir = tempfile.mkdtemp()
self.expected_shape = (128, 128, 128)
diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py
index dcae5fdce1..dbb1b8f8f1 100644
--- a/tests/test_cachedataset.py
+++ b/tests/test_cachedataset.py
@@ -39,6 +39,7 @@
class TestCacheDataset(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape(self, transform, expected_shape):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
diff --git a/tests/test_cachedataset_parallel.py b/tests/test_cachedataset_parallel.py
index c3fc2cc362..6a01a82512 100644
--- a/tests/test_cachedataset_parallel.py
+++ b/tests/test_cachedataset_parallel.py
@@ -30,6 +30,7 @@
class TestCacheDatasetParallel(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_shape(self, num_workers, dataset_size, transform):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]).astype(float), np.eye(4))
diff --git a/tests/test_cachedataset_persistent_workers.py b/tests/test_cachedataset_persistent_workers.py
index e60862238d..78092906c6 100644
--- a/tests/test_cachedataset_persistent_workers.py
+++ b/tests/test_cachedataset_persistent_workers.py
@@ -18,6 +18,7 @@
class TestTransformsWCacheDatasetAndPersistentWorkers(unittest.TestCase):
+
def test_duplicate_transforms(self):
data = [{"img": create_test_image_2d(128, 128, num_seg_classes=1, channel_dim=0)[0]} for _ in range(2)]
diff --git a/tests/test_cachentransdataset.py b/tests/test_cachentransdataset.py
index d50fe4f8dd..90e86c2eb0 100644
--- a/tests/test_cachentransdataset.py
+++ b/tests/test_cachentransdataset.py
@@ -34,6 +34,7 @@
class TestCacheNTransDataset(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1])
def test_n_trans(self, transform, expected_shape):
data_array = np.random.randint(0, 2, size=[128, 128, 128]).astype(float)
diff --git a/tests/test_call_dist.py b/tests/test_call_dist.py
index 0621824b65..503cb5e792 100644
--- a/tests/test_call_dist.py
+++ b/tests/test_call_dist.py
@@ -17,6 +17,7 @@
class DistributedCallTest(DistTestCase):
+
def test_constructor(self):
with self.assertRaises(ValueError):
DistCall(nnodes=1, nproc_per_node=0)
diff --git a/tests/test_cast_to_type.py b/tests/test_cast_to_type.py
index 6dd994120c..035260804e 100644
--- a/tests/test_cast_to_type.py
+++ b/tests/test_cast_to_type.py
@@ -37,6 +37,7 @@
class TestCastToType(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_type(self, out_dtype, input_data, expected_type):
result = CastToType(dtype=out_dtype)(input_data)
diff --git a/tests/test_cast_to_typed.py b/tests/test_cast_to_typed.py
index 687deeda1d..81e17117a9 100644
--- a/tests/test_cast_to_typed.py
+++ b/tests/test_cast_to_typed.py
@@ -53,6 +53,7 @@
class TestCastToTyped(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_type(self, input_param, input_data, expected_type):
result = CastToTyped(**input_param)(input_data)
diff --git a/tests/test_channel_pad.py b/tests/test_channel_pad.py
index 2d8c57fd68..77dd172378 100644
--- a/tests/test_channel_pad.py
+++ b/tests/test_channel_pad.py
@@ -34,6 +34,7 @@
class TestChannelPad(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_3D)
def test_shape(self, input_param, input_shape, expected_shape):
net = ChannelPad(**input_param)
diff --git a/tests/test_check_hash.py b/tests/test_check_hash.py
index bb3d0ff12e..263c18703c 100644
--- a/tests/test_check_hash.py
+++ b/tests/test_check_hash.py
@@ -32,6 +32,7 @@
class TestCheckMD5(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_result(self, md5_value, t, expected_result):
test_image = np.ones((5, 5, 3))
diff --git a/tests/test_check_missing_files.py b/tests/test_check_missing_files.py
index efbe5a95fb..2b5c17a1ec 100644
--- a/tests/test_check_missing_files.py
+++ b/tests/test_check_missing_files.py
@@ -23,6 +23,7 @@
class TestCheckMissingFiles(unittest.TestCase):
+
def test_content(self):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_classes_to_indices.py b/tests/test_classes_to_indices.py
index e7dd7abfe5..a7377dac16 100644
--- a/tests/test_classes_to_indices.py
+++ b/tests/test_classes_to_indices.py
@@ -82,6 +82,7 @@
class TestClassesToIndices(unittest.TestCase):
+
@parameterized.expand(TESTS_CASES)
def test_value(self, input_args, label, image, expected_indices):
indices = ClassesToIndices(**input_args)(label, image)
diff --git a/tests/test_classes_to_indicesd.py b/tests/test_classes_to_indicesd.py
index 7a34cc06b4..dead1ae753 100644
--- a/tests/test_classes_to_indicesd.py
+++ b/tests/test_classes_to_indicesd.py
@@ -97,6 +97,7 @@
class TestClassesToIndicesd(unittest.TestCase):
+
@parameterized.expand(TESTS_CASES)
def test_value(self, input_args, input_data, expected_indices):
result = ClassesToIndicesd(**input_args)(input_data)
diff --git a/tests/test_cldice_loss.py b/tests/test_cldice_loss.py
index 071bd20d6c..14d3575e3b 100644
--- a/tests/test_cldice_loss.py
+++ b/tests/test_cldice_loss.py
@@ -23,6 +23,7 @@
class TestclDiceLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_result(self, y_pred_data, expected_val):
loss = SoftclDiceLoss()
diff --git a/tests/test_complex_utils.py b/tests/test_complex_utils.py
index 77eaa924a2..fdcee4babe 100644
--- a/tests/test_complex_utils.py
+++ b/tests/test_complex_utils.py
@@ -51,6 +51,7 @@
class TestMRIUtils(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_to_tensor_complex(self, test_data, expected_shape):
result = convert_to_tensor_complex(test_data)
diff --git a/tests/test_component_locator.py b/tests/test_component_locator.py
index 3b54a13706..9378fc159d 100644
--- a/tests/test_component_locator.py
+++ b/tests/test_component_locator.py
@@ -21,6 +21,7 @@
class TestComponentLocator(unittest.TestCase):
+
def test_locate(self):
locator = ComponentLocator(excludes=None if has_ignite else ["monai.handlers"])
# test init mapping table and get the module path of component
diff --git a/tests/test_component_store.py b/tests/test_component_store.py
index 614f387754..424eceb3d1 100644
--- a/tests/test_component_store.py
+++ b/tests/test_component_store.py
@@ -17,6 +17,7 @@
class TestComponentStore(unittest.TestCase):
+
def setUp(self):
self.cs = ComponentStore("TestStore", "I am a test store, please ignore")
diff --git a/tests/test_compose.py b/tests/test_compose.py
index a1952b102f..309767833b 100644
--- a/tests/test_compose.py
+++ b/tests/test_compose.py
@@ -39,6 +39,7 @@ def data_from_keys(keys, h, w):
class _RandXform(Randomizable):
+
def randomize(self):
self.val = self.R.random_sample()
@@ -48,12 +49,14 @@ def __call__(self, __unused):
class TestCompose(unittest.TestCase):
+
def test_empty_compose(self):
c = mt.Compose()
i = 1
self.assertEqual(c(i), 1)
def test_non_dict_compose(self):
+
def a(i):
return i + "a"
@@ -64,6 +67,7 @@ def b(i):
self.assertEqual(c(""), "abab")
def test_dict_compose(self):
+
def a(d):
d = dict(d)
d["a"] += 1
@@ -82,6 +86,7 @@ def b(d):
self.assertDictEqual(execute_compose(data, transforms), expected)
def test_list_dict_compose(self):
+
def a(d): # transform to handle dict data
d = dict(d)
d["a"] += 1
@@ -109,6 +114,7 @@ def c(d): # transform to handle dict data
self.assertDictEqual(item, expected)
def test_non_dict_compose_with_unpack(self):
+
def a(i, i2):
return i + "a", i2 + "a2"
@@ -122,6 +128,7 @@ def b(i, i2):
self.assertEqual(execute_compose(data, transforms, map_items=False, unpack_items=True), expected)
def test_list_non_dict_compose_with_unpack(self):
+
def a(i, i2):
return i + "a", i2 + "a2"
@@ -135,6 +142,7 @@ def b(i, i2):
self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected)
def test_list_dict_compose_no_map(self):
+
def a(d): # transform to handle dict data
d = dict(d)
d["a"] += 1
@@ -163,6 +171,7 @@ def c(d): # transform to handle dict data
self.assertDictEqual(item, expected)
def test_random_compose(self):
+
class _Acc(Randomizable):
self.rand = 0.0
@@ -182,7 +191,9 @@ def __call__(self, data):
self.assertAlmostEqual(c(1), 1.90734751)
def test_randomize_warn(self):
+
class _RandomClass(Randomizable):
+
def randomize(self, foo1, foo2):
pass
@@ -267,6 +278,7 @@ def test_backwards_compatible_imports(self):
class TestComposeExecute(unittest.TestCase):
+
@parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES)
def test_compose_execute_equivalence(self, keys, pipeline):
data = data_from_keys(keys, 12, 16)
@@ -657,8 +669,10 @@ def test_compose_lazy_on_call_with_logging(self, compose_type, pipeline, lazy_on
class TestOps:
+
@staticmethod
def concat(value):
+
def _inner(data):
return data + value
@@ -666,6 +680,7 @@ def _inner(data):
@staticmethod
def concatd(value):
+
def _inner(data):
return {k: v + value for k, v in data.items()}
@@ -673,6 +688,7 @@ def _inner(data):
@staticmethod
def concata(value):
+
def _inner(data1, data2):
return data1 + value, data2 + value
@@ -688,6 +704,7 @@ def _inner(data1, data2):
class TestComposeExecuteWithFlags(unittest.TestCase):
+
@parameterized.expand(TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES)
def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline):
expected = mt.Compose(pipeline, **flags)(data)
@@ -711,6 +728,7 @@ def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline):
class TestComposeCallableInput(unittest.TestCase):
+
def test_value_error_when_not_sequence(self):
data = torch.tensor(np.random.randn(1, 5, 5))
diff --git a/tests/test_compose_get_number_conversions.py b/tests/test_compose_get_number_conversions.py
index 664558d9cd..2623bab69c 100644
--- a/tests/test_compose_get_number_conversions.py
+++ b/tests/test_compose_get_number_conversions.py
@@ -38,6 +38,7 @@ def _apply(x, fn):
class Load(Transform):
+
def __init__(self, as_tensor):
self.fn = lambda _: PT_ARR if as_tensor else NP_ARR
@@ -46,26 +47,31 @@ def __call__(self, x):
class N(Transform):
+
def __call__(self, x):
return _apply(x, convert_to_numpy)
class T(Transform):
+
def __call__(self, x):
return _apply(x, convert_to_tensor)
class NT(Transform):
+
def __call__(self, x):
return _apply(x, lambda x: x)
class TCPU(Transform):
+
def __call__(self, x):
return _apply(x, lambda x: convert_to_tensor(x).cpu())
class TGPU(Transform):
+
def __call__(self, x):
return _apply(x, lambda x: convert_to_tensor(x).cuda())
@@ -103,6 +109,7 @@ def __call__(self, x):
class TestComposeNumConversions(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_get_number_of_conversions(self, transforms, is_dict, input, expected):
input = input if not is_dict else {KEY: input, "Other": NP_ARR}
diff --git a/tests/test_compute_confusion_matrix.py b/tests/test_compute_confusion_matrix.py
index e0a92aec67..248f16a7fe 100644
--- a/tests/test_compute_confusion_matrix.py
+++ b/tests/test_compute_confusion_matrix.py
@@ -220,6 +220,7 @@
class TestConfusionMatrix(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_CONFUSION_MATRIX])
def test_value(self, input_data, expected_value):
# include or ignore background
diff --git a/tests/test_compute_f_beta.py b/tests/test_compute_f_beta.py
index c8ed5aa887..85997577cf 100644
--- a/tests/test_compute_f_beta.py
+++ b/tests/test_compute_f_beta.py
@@ -23,6 +23,7 @@
class TestFBetaScore(unittest.TestCase):
+
def test_expecting_success_and_device(self):
metric = FBetaScore()
y_pred = torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]], device=_device)
diff --git a/tests/test_compute_fid_metric.py b/tests/test_compute_fid_metric.py
index 1c7c3273fe..bd867f5296 100644
--- a/tests/test_compute_fid_metric.py
+++ b/tests/test_compute_fid_metric.py
@@ -24,6 +24,7 @@
@unittest.skipUnless(has_scipy, "Requires scipy")
class TestFIDMetric(unittest.TestCase):
+
def test_results(self):
x = torch.Tensor([[1, 2], [1, 2], [1, 2]])
y = torch.Tensor([[2, 2], [1, 2], [1, 2]])
diff --git a/tests/test_compute_froc.py b/tests/test_compute_froc.py
index 0a48dc099a..4dc0507366 100644
--- a/tests/test_compute_froc.py
+++ b/tests/test_compute_froc.py
@@ -111,6 +111,7 @@
class TestComputeFpTp(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_value(self, input_data, expected_fp, expected_tp, expected_num):
fp_probs, tp_probs, num_tumors = compute_fp_tp_probs(**input_data)
@@ -120,6 +121,7 @@ def test_value(self, input_data, expected_fp, expected_tp, expected_num):
class TestComputeFpTpNd(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_ND_1, TEST_CASE_ND_2])
def test_value(self, input_data, expected_fp, expected_tp, expected_num):
fp_probs, tp_probs, num_tumors = compute_fp_tp_probs_nd(**input_data)
@@ -129,6 +131,7 @@ def test_value(self, input_data, expected_fp, expected_tp, expected_num):
class TestComputeFrocScore(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_4, TEST_CASE_5])
def test_value(self, input_data, thresholds, expected_score):
fps_per_image, total_sensitivity = compute_froc_curve_data(**input_data)
diff --git a/tests/test_compute_generalized_dice.py b/tests/test_compute_generalized_dice.py
index ab3d012c97..e04444e988 100644
--- a/tests/test_compute_generalized_dice.py
+++ b/tests/test_compute_generalized_dice.py
@@ -119,6 +119,7 @@
class TestComputeGeneralizedDiceScore(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1])
def test_device(self, input_data, _expected_value):
result = compute_generalized_dice(**input_data)
diff --git a/tests/test_compute_ho_ver_maps.py b/tests/test_compute_ho_ver_maps.py
index 50598cb57b..bbd5230f04 100644
--- a/tests/test_compute_ho_ver_maps.py
+++ b/tests/test_compute_ho_ver_maps.py
@@ -62,6 +62,7 @@
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
class ComputeHoVerMapsTests(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask):
input_image = in_type(mask)
diff --git a/tests/test_compute_ho_ver_maps_d.py b/tests/test_compute_ho_ver_maps_d.py
index 27bb57988c..7b5ac0d9d7 100644
--- a/tests/test_compute_ho_ver_maps_d.py
+++ b/tests/test_compute_ho_ver_maps_d.py
@@ -63,6 +63,7 @@
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
class ComputeHoVerMapsDictTests(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask):
hv_key = list(hv_mask.keys())[0]
diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py
index 46e1d67b1b..aae15483b5 100644
--- a/tests/test_compute_meandice.py
+++ b/tests/test_compute_meandice.py
@@ -252,6 +252,7 @@
class TestComputeMeanDice(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])
def test_value(self, input_data, expected_value):
result = compute_dice(**input_data)
diff --git a/tests/test_compute_meaniou.py b/tests/test_compute_meaniou.py
index d39edaa6f3..0b7a2bbce2 100644
--- a/tests/test_compute_meaniou.py
+++ b/tests/test_compute_meaniou.py
@@ -187,6 +187,7 @@
class TestComputeMeanIoU(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])
def test_value(self, input_data, expected_value):
result = compute_iou(**input_data)
diff --git a/tests/test_compute_mmd_metric.py b/tests/test_compute_mmd_metric.py
index d1b69b3dfe..96b5cbc089 100644
--- a/tests/test_compute_mmd_metric.py
+++ b/tests/test_compute_mmd_metric.py
@@ -36,6 +36,7 @@
class TestMMDMetric(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_results(self, input_param, input_data, expected_val):
metric = MMDMetric(**input_param)
diff --git a/tests/test_compute_multiscalessim_metric.py b/tests/test_compute_multiscalessim_metric.py
index 4ebc5b7935..3df8026c2b 100644
--- a/tests/test_compute_multiscalessim_metric.py
+++ b/tests/test_compute_multiscalessim_metric.py
@@ -20,6 +20,7 @@
class TestMultiScaleSSIMMetric(unittest.TestCase):
+
def test2d_gaussian(self):
set_determinism(0)
preds = torch.abs(torch.randn(1, 1, 64, 64))
diff --git a/tests/test_compute_panoptic_quality.py b/tests/test_compute_panoptic_quality.py
index a5858e91d1..a916ea32b2 100644
--- a/tests/test_compute_panoptic_quality.py
+++ b/tests/test_compute_panoptic_quality.py
@@ -92,6 +92,7 @@
@SkipIfNoModule("scipy.optimize")
class TestPanopticQualityMetric(unittest.TestCase):
+
@parameterized.expand([TEST_FUNC_CASE_1, TEST_FUNC_CASE_2, TEST_FUNC_CASE_3, TEST_FUNC_CASE_4])
def test_value(self, input_params, expected_value):
result = compute_panoptic_quality(**input_params)
diff --git a/tests/test_compute_regression_metrics.py b/tests/test_compute_regression_metrics.py
index b0fde3afe9..a8b7f03e47 100644
--- a/tests/test_compute_regression_metrics.py
+++ b/tests/test_compute_regression_metrics.py
@@ -45,6 +45,7 @@ def psnrmetric_np(max_val, y_pred, y):
class TestRegressionMetrics(unittest.TestCase):
+
def test_shape_reduction(self):
set_determinism(seed=123)
device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py
index 2f080c76cb..f2cb816db4 100644
--- a/tests/test_compute_roc_auc.py
+++ b/tests/test_compute_roc_auc.py
@@ -100,6 +100,7 @@
class TestComputeROCAUC(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_1,
diff --git a/tests/test_compute_variance.py b/tests/test_compute_variance.py
index 8eaac10a6c..486a1e9f6f 100644
--- a/tests/test_compute_variance.py
+++ b/tests/test_compute_variance.py
@@ -109,6 +109,7 @@
class TestComputeVariance(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_value(self, input_data, expected_value):
result = compute_variance(**input_data)
diff --git a/tests/test_concat_itemsd.py b/tests/test_concat_itemsd.py
index 322a95d7df..64c5d6e255 100644
--- a/tests/test_concat_itemsd.py
+++ b/tests/test_concat_itemsd.py
@@ -22,6 +22,7 @@
class TestConcatItemsd(unittest.TestCase):
+
def test_tensor_values(self):
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
input_data = {
diff --git a/tests/test_config_item.py b/tests/test_config_item.py
index cb1e7ad552..72f54adf0a 100644
--- a/tests/test_config_item.py
+++ b/tests/test_config_item.py
@@ -52,6 +52,7 @@
class TestConfigItem(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1])
def test_item(self, test_input, expected):
item = ConfigItem(config=test_input)
diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py
index 63254e7336..41d7aa7a4e 100644
--- a/tests/test_config_parser.py
+++ b/tests/test_config_parser.py
@@ -72,6 +72,7 @@ def case_pdb_inst(sarg=None):
class TestClass:
+
@staticmethod
def compute(a, b, func=lambda x, y: x + y):
return func(a, b)
@@ -126,6 +127,7 @@ def __call__(self, a, b):
class TestConfigParser(unittest.TestCase):
+
def test_config_content(self):
test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}}
parser = ConfigParser(config=test_config)
diff --git a/tests/test_contrastive_loss.py b/tests/test_contrastive_loss.py
index 4cafa0a905..21a9e76417 100644
--- a/tests/test_contrastive_loss.py
+++ b/tests/test_contrastive_loss.py
@@ -55,6 +55,7 @@
class TestContrastiveLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_result(self, input_param, input_data, expected_val):
contrastiveloss = ContrastiveLoss(**input_param)
diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py
index c3e4490ffe..b95539f4b7 100644
--- a/tests/test_convert_data_type.py
+++ b/tests/test_convert_data_type.py
@@ -77,6 +77,7 @@ class TestTensor(torch.Tensor):
class TestConvertDataType(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_convert_data_type(self, in_image, im_out, out_dtype, safe):
converted_im, orig_type, orig_device = convert_data_type(in_image, type(im_out), dtype=out_dtype, safe=safe)
diff --git a/tests/test_convert_to_multi_channel.py b/tests/test_convert_to_multi_channel.py
index 78c3c90688..98bbea1ebf 100644
--- a/tests/test_convert_to_multi_channel.py
+++ b/tests/test_convert_to_multi_channel.py
@@ -48,6 +48,7 @@
class TestConvertToMultiChannel(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_type_shape(self, data, expected_result):
result = ConvertToMultiChannelBasedOnBratsClasses()(data)
diff --git a/tests/test_convert_to_multi_channeld.py b/tests/test_convert_to_multi_channeld.py
index 351adddb13..e482770497 100644
--- a/tests/test_convert_to_multi_channeld.py
+++ b/tests/test_convert_to_multi_channeld.py
@@ -26,6 +26,7 @@
class TestConvertToMultiChanneld(unittest.TestCase):
+
@parameterized.expand([TEST_CASE])
def test_type_shape(self, keys, data, expected_result):
result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data)
diff --git a/tests/test_convert_to_onnx.py b/tests/test_convert_to_onnx.py
index 7560c98703..398d260c52 100644
--- a/tests/test_convert_to_onnx.py
+++ b/tests/test_convert_to_onnx.py
@@ -36,6 +36,7 @@
@SkipIfBeforePyTorchVersion((1, 9))
@skip_if_quick
class TestConvertToOnnx(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_unet(self, device, use_trace, use_ort):
if use_ort:
diff --git a/tests/test_convert_to_torchscript.py b/tests/test_convert_to_torchscript.py
index 0b8e9a8141..c78b8e78c0 100644
--- a/tests/test_convert_to_torchscript.py
+++ b/tests/test_convert_to_torchscript.py
@@ -22,6 +22,7 @@
class TestConvertToTorchScript(unittest.TestCase):
+
def test_value(self):
model = UNet(
spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
diff --git a/tests/test_convert_to_trt.py b/tests/test_convert_to_trt.py
index 108ed66f31..5579539764 100644
--- a/tests/test_convert_to_trt.py
+++ b/tests/test_convert_to_trt.py
@@ -39,6 +39,7 @@
@skip_if_no_cuda
@skip_if_quick
class TestConvertToTRT(unittest.TestCase):
+
def setUp(self):
self.gpu_device = torch.cuda.current_device()
diff --git a/tests/test_convolutions.py b/tests/test_convolutions.py
index 1311401f1d..77bc12770f 100644
--- a/tests/test_convolutions.py
+++ b/tests/test_convolutions.py
@@ -18,6 +18,7 @@
class TestConvolution2D(TorchImageTestCase2D):
+
def test_conv1(self):
conv = Convolution(2, self.input_channels, self.output_channels)
out = conv(self.imt)
@@ -69,6 +70,7 @@ def test_transpose2(self):
class TestConvolution3D(TorchImageTestCase3D):
+
def test_conv1(self):
conv = Convolution(3, self.input_channels, self.output_channels, dropout=0.1, adn_ordering="DAN")
out = conv(self.imt)
@@ -126,6 +128,7 @@ def test_transpose2(self):
class TestResidualUnit2D(TorchImageTestCase2D):
+
def test_conv_only1(self):
conv = ResidualUnit(2, 1, self.output_channels)
out = conv(self.imt)
diff --git a/tests/test_copy_itemsd.py b/tests/test_copy_itemsd.py
index ff4799a094..a78e08897b 100644
--- a/tests/test_copy_itemsd.py
+++ b/tests/test_copy_itemsd.py
@@ -32,6 +32,7 @@
class TestCopyItemsd(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_numpy_values(self, keys, times, names):
input_data = {"img": np.array([[0, 1], [1, 2]]), "seg": np.array([[3, 4], [4, 5]])}
diff --git a/tests/test_copy_model_state.py b/tests/test_copy_model_state.py
index 2e7513b234..26b01d930a 100644
--- a/tests/test_copy_model_state.py
+++ b/tests/test_copy_model_state.py
@@ -22,6 +22,7 @@
class _TestModelOne(torch.nn.Module):
+
def __init__(self, n_n, n_m, n_class):
super().__init__()
self.layer = torch.nn.Linear(n_n, n_m)
@@ -34,6 +35,7 @@ def forward(self, x):
class _TestModelTwo(torch.nn.Module):
+
def __init__(self, n_n, n_m, n_d, n_class):
super().__init__()
self.layer = torch.nn.Linear(n_n, n_m)
@@ -55,6 +57,7 @@ def forward(self, x):
class TestModuleState(unittest.TestCase):
+
def tearDown(self):
set_determinism(None)
diff --git a/tests/test_correct_crop_centers.py b/tests/test_correct_crop_centers.py
index d2a95bf684..82b0b93b53 100644
--- a/tests/test_correct_crop_centers.py
+++ b/tests/test_correct_crop_centers.py
@@ -23,6 +23,7 @@
class TestCorrectCropCenters(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_torch(self, spatial_size, centers, label_spatial_shape):
result1 = correct_crop_centers(centers, spatial_size, label_spatial_shape)
diff --git a/tests/test_create_cross_validation_datalist.py b/tests/test_create_cross_validation_datalist.py
index d05a94f59e..0e80be1cd0 100644
--- a/tests/test_create_cross_validation_datalist.py
+++ b/tests/test_create_cross_validation_datalist.py
@@ -20,6 +20,7 @@
class TestCreateCrossValidationDatalist(unittest.TestCase):
+
def test_content(self):
with tempfile.TemporaryDirectory() as tempdir:
datalist = []
diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py
index 2b5890a777..4910a10470 100644
--- a/tests/test_create_grid_and_affine.py
+++ b/tests/test_create_grid_and_affine.py
@@ -28,6 +28,7 @@
class TestCreateGrid(unittest.TestCase):
+
def test_create_grid(self):
with self.assertRaisesRegex(TypeError, ""):
create_grid(None)
@@ -168,6 +169,7 @@ def test_assert(func, params, expected):
class TestCreateAffine(unittest.TestCase):
+
def test_create_rotate(self):
with self.assertRaisesRegex(TypeError, ""):
create_rotate(2, None)
diff --git a/tests/test_crf_cpu.py b/tests/test_crf_cpu.py
index e29a4d69eb..a7ae0ff2df 100644
--- a/tests/test_crf_cpu.py
+++ b/tests/test_crf_cpu.py
@@ -495,6 +495,7 @@
@skip_if_no_cpp_extension
class CRFTestCaseCpu(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test(self, test_case_description, params, input, features, expected):
# Create input tensors
diff --git a/tests/test_crf_cuda.py b/tests/test_crf_cuda.py
index 8529e2e6de..d5329aab15 100644
--- a/tests/test_crf_cuda.py
+++ b/tests/test_crf_cuda.py
@@ -496,6 +496,7 @@
@skip_if_no_cpp_extension
@skip_if_no_cuda
class CRFTestCaseCuda(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test(self, test_case_description, params, input, features, expected):
# Create input tensors
diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py
index 4435b128ba..f63cb3e8b0 100644
--- a/tests/test_crop_foreground.py
+++ b/tests/test_crop_foreground.py
@@ -99,6 +99,7 @@
class TestCropForeground(unittest.TestCase):
+
@parameterized.expand(TEST_COORDS + TESTS)
def test_value(self, arguments, image, expected_data, _):
cropper = CropForeground(**arguments)
diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py
index 776776f6c5..92954aa81e 100644
--- a/tests/test_crop_foregroundd.py
+++ b/tests/test_crop_foregroundd.py
@@ -158,6 +158,7 @@
class TestCropForegroundd(unittest.TestCase):
+
@parameterized.expand(TEST_POSITION + TESTS)
def test_value(self, arguments, input_data, expected_data, _):
cropper = CropForegroundd(**arguments)
diff --git a/tests/test_cross_validation.py b/tests/test_cross_validation.py
index de1122eeae..6d0f2319fb 100644
--- a/tests/test_cross_validation.py
+++ b/tests/test_cross_validation.py
@@ -21,6 +21,7 @@
class TestCrossValidation(unittest.TestCase):
+
@skip_if_quick
def test_values(self):
testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
diff --git a/tests/test_csv_dataset.py b/tests/test_csv_dataset.py
index 82a0f7afbd..71be4fdd22 100644
--- a/tests/test_csv_dataset.py
+++ b/tests/test_csv_dataset.py
@@ -23,6 +23,7 @@
class TestCSVDataset(unittest.TestCase):
+
def test_values(self):
with tempfile.TemporaryDirectory() as tempdir:
test_data1 = [
diff --git a/tests/test_csv_iterable_dataset.py b/tests/test_csv_iterable_dataset.py
index 65a0a420a5..e06da0c41b 100644
--- a/tests/test_csv_iterable_dataset.py
+++ b/tests/test_csv_iterable_dataset.py
@@ -26,6 +26,7 @@
@skip_if_windows
class TestCSVIterableDataset(unittest.TestCase):
+
def test_values(self):
with tempfile.TemporaryDirectory() as tempdir:
test_data1 = [
diff --git a/tests/test_csv_saver.py b/tests/test_csv_saver.py
index 833d1134cf..234b3f1057 100644
--- a/tests/test_csv_saver.py
+++ b/tests/test_csv_saver.py
@@ -23,6 +23,7 @@
class TestCSVSaver(unittest.TestCase):
+
def test_saved_content(self):
with tempfile.TemporaryDirectory() as tempdir:
saver = CSVSaver(output_dir=tempdir, filename="predictions.csv", delimiter="\t")
diff --git a/tests/test_cucim_dict_transform.py b/tests/test_cucim_dict_transform.py
index 6ebfd8bac7..d2dcc6aa5f 100644
--- a/tests/test_cucim_dict_transform.py
+++ b/tests/test_cucim_dict_transform.py
@@ -66,6 +66,7 @@
@unittest.skipUnless(HAS_CUPY, "CuPy is required.")
@unittest.skipUnless(has_cut, "cuCIM transforms are required.")
class TestCuCIMDict(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_COLOR_JITTER_1,
diff --git a/tests/test_cucim_transform.py b/tests/test_cucim_transform.py
index 5884358a74..5f16c11589 100644
--- a/tests/test_cucim_transform.py
+++ b/tests/test_cucim_transform.py
@@ -66,6 +66,7 @@
@unittest.skipUnless(HAS_CUPY, "CuPy is required.")
@unittest.skipUnless(has_cut, "cuCIM transforms are required.")
class TestCuCIM(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_COLOR_JITTER_1,
diff --git a/tests/test_cumulative.py b/tests/test_cumulative.py
index 3377fa815c..d3b6ba094c 100644
--- a/tests/test_cumulative.py
+++ b/tests/test_cumulative.py
@@ -20,6 +20,7 @@
class TestCumulative(unittest.TestCase):
+
def test_single(self):
c = Cumulative()
c.extend([2, 3])
diff --git a/tests/test_cumulative_average.py b/tests/test_cumulative_average.py
index d815d9be77..624da2c7bb 100644
--- a/tests/test_cumulative_average.py
+++ b/tests/test_cumulative_average.py
@@ -32,6 +32,7 @@
class TestAverageMeter(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_1)
def test_value_all(self, data):
# test orig
diff --git a/tests/test_cumulative_average_dist.py b/tests/test_cumulative_average_dist.py
index 17f4164838..30c01c21ee 100644
--- a/tests/test_cumulative_average_dist.py
+++ b/tests/test_cumulative_average_dist.py
@@ -23,6 +23,7 @@
@SkipIfBeforePyTorchVersion((1, 8))
class DistributedCumulativeAverage(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2)
def test_value(self):
rank = dist.get_rank()
diff --git a/tests/test_cv2_dist.py b/tests/test_cv2_dist.py
index edd2e1ec42..562c205763 100644
--- a/tests/test_cv2_dist.py
+++ b/tests/test_cv2_dist.py
@@ -42,6 +42,7 @@ def main_worker(rank, ngpus_per_node, port):
@skip_if_no_cuda
class TestCV2Dist(unittest.TestCase):
+
def test_cv2_cuda_ops(self):
print_config()
ngpus_per_node = torch.cuda.device_count()
diff --git a/tests/test_daf3d.py b/tests/test_daf3d.py
index 34e25cc6be..d20cb3cfd1 100644
--- a/tests/test_daf3d.py
+++ b/tests/test_daf3d.py
@@ -42,6 +42,7 @@
@unittest.skipUnless(has_tv, "torchvision not installed")
class TestDAF3D(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_shape, expected_shape):
device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py
index 6ef51bef92..05453b0694 100644
--- a/tests/test_data_stats.py
+++ b/tests/test_data_stats.py
@@ -137,6 +137,7 @@
class TestDataStats(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
def test_value(self, input_param, input_data, expected_print):
transform = DataStats(**input_param)
diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py
index 374bc815ac..ef88300c10 100644
--- a/tests/test_data_statsd.py
+++ b/tests/test_data_statsd.py
@@ -157,6 +157,7 @@
class TestDataStatsd(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]
)
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 2ee69687a6..73e27799f7 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -29,6 +29,7 @@
class TestDataLoader(unittest.TestCase):
+
def test_values(self):
datalist = [
{"image": "spleen_19.nii.gz", "label": "spleen_label_19.nii.gz"},
@@ -59,6 +60,7 @@ def test_exception(self, datalist):
class _RandomDataset(torch.utils.data.Dataset, Randomizable):
+
def __getitem__(self, index):
return self.R.randint(0, 1000, (1,))
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index c7c2b77697..1398009c63 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -30,6 +30,7 @@
class TestDataset(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1])
def test_shape(self, expected_shape):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
diff --git a/tests/test_dataset_func.py b/tests/test_dataset_func.py
index afccd129fe..166d888d9e 100644
--- a/tests/test_dataset_func.py
+++ b/tests/test_dataset_func.py
@@ -20,6 +20,7 @@
class TestDatasetFunc(unittest.TestCase):
+
def test_seg_values(self):
with tempfile.TemporaryDirectory() as tempdir:
# prepare test datalist file
diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py
index 87538425d5..21cc53de90 100644
--- a/tests/test_dataset_summary.py
+++ b/tests/test_dataset_summary.py
@@ -36,6 +36,7 @@ def test_collate(batch):
class TestDatasetSummary(unittest.TestCase):
+
def test_spacing_intensity(self):
set_determinism(seed=0)
with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py
index 345cc487c5..d220cd9097 100644
--- a/tests/test_decathlondataset.py
+++ b/tests/test_decathlondataset.py
@@ -23,6 +23,7 @@
class TestDecathlonDataset(unittest.TestCase):
+
@skip_if_quick
def test_values(self):
testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
diff --git a/tests/test_decollate.py b/tests/test_decollate.py
index 26b9e7a4f4..92f7c89e28 100644
--- a/tests/test_decollate.py
+++ b/tests/test_decollate.py
@@ -81,6 +81,7 @@
class TestDeCollate(unittest.TestCase):
+
def setUp(self) -> None:
set_determinism(seed=0)
@@ -159,6 +160,7 @@ def test_decollation_list(self, *transforms):
class TestBasicDeCollate(unittest.TestCase):
+
@parameterized.expand(TEST_BASIC)
def test_decollation_examples(self, input_val, expected_out):
out = decollate_batch(input_val)
diff --git a/tests/test_deepedit_interaction.py b/tests/test_deepedit_interaction.py
index 5dcc6205f7..8baf4dc827 100644
--- a/tests/test_deepedit_interaction.py
+++ b/tests/test_deepedit_interaction.py
@@ -40,6 +40,7 @@ def add_one(engine):
class TestInteractions(unittest.TestCase):
+
def run_interaction(self, train):
label_names = {"spleen": 1, "background": 0}
np.random.seed(0)
diff --git a/tests/test_deepedit_transforms.py b/tests/test_deepedit_transforms.py
index 7f4d4eee1e..18d6567fd7 100644
--- a/tests/test_deepedit_transforms.py
+++ b/tests/test_deepedit_transforms.py
@@ -209,6 +209,7 @@
class TestAddGuidanceFromPointsCustomd(unittest.TestCase):
+
@parameterized.expand([ADD_GUIDANCE_FROM_POINTS_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = AddGuidanceFromPointsDeepEditd(**arguments)
@@ -217,6 +218,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestAddGuidanceSignalCustomd(unittest.TestCase):
+
@parameterized.expand([ADD_GUIDANCE_CUSTOM_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = AddGuidanceSignalDeepEditd(**arguments)
@@ -225,6 +227,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestAddInitialSeedPointMissingLabelsd(unittest.TestCase):
+
@parameterized.expand([ADD_INITIAL_POINT_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
seed = 0
@@ -235,6 +238,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestAddRandomGuidanceCustomd(unittest.TestCase):
+
@parameterized.expand([ADD_RANDOM_GUIDANCE_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = AddRandomGuidanceDeepEditd(**arguments)
@@ -244,6 +248,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestDiscardAddGuidanced(unittest.TestCase):
+
@parameterized.expand([DISCARD_ADD_GUIDANCE_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = DiscardAddGuidanced(**arguments)
@@ -252,6 +257,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestFindAllValidSlicesMissingLabelsd(unittest.TestCase):
+
@parameterized.expand([FIND_SLICE_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = FindAllValidSlicesMissingLabelsd(**arguments)
@@ -260,6 +266,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestFindDiscrepancyRegionsCustomd(unittest.TestCase):
+
@parameterized.expand([FIND_DISCREPANCY_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = FindDiscrepancyRegionsDeepEditd(**arguments)
@@ -268,6 +275,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestNormalizeLabelsDatasetd(unittest.TestCase):
+
@parameterized.expand([NormalizeLabelsDatasetd_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = NormalizeLabelsInDatasetd(**arguments)
@@ -276,6 +284,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestResizeGuidanceMultipleLabelCustomd(unittest.TestCase):
+
@parameterized.expand([RESIZE_GUIDANCE_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = ResizeGuidanceMultipleLabelDeepEditd(**arguments)
@@ -284,6 +293,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestSingleLabelSelectiond(unittest.TestCase):
+
@parameterized.expand([SingleLabelSelectiond_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = SingleLabelSelectiond(**arguments)
@@ -292,6 +302,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestSplitPredsLabeld(unittest.TestCase):
+
@parameterized.expand([SplitPredsLabeld_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = SplitPredsLabeld(**arguments)
diff --git a/tests/test_deepgrow_dataset.py b/tests/test_deepgrow_dataset.py
index d8a412ade9..b8d630960c 100644
--- a/tests/test_deepgrow_dataset.py
+++ b/tests/test_deepgrow_dataset.py
@@ -51,6 +51,7 @@
class TestCreateDataset(unittest.TestCase):
+
def setUp(self):
set_determinism(1)
self.tempdir = tempfile.mkdtemp()
diff --git a/tests/test_deepgrow_interaction.py b/tests/test_deepgrow_interaction.py
index 7cdbeed9f9..35759699f8 100644
--- a/tests/test_deepgrow_interaction.py
+++ b/tests/test_deepgrow_interaction.py
@@ -38,6 +38,7 @@ def add_one(engine):
class TestInteractions(unittest.TestCase):
+
def run_interaction(self, train, compose):
data = [{"image": np.ones((1, 2, 2, 2)).astype(np.float32), "label": np.ones((1, 2, 2, 2))} for _ in range(5)]
network = torch.nn.Linear(2, 2)
diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py
index 1328e13439..a491a8004b 100644
--- a/tests/test_deepgrow_transforms.py
+++ b/tests/test_deepgrow_transforms.py
@@ -337,6 +337,7 @@
class TestFindAllValidSlicesd(unittest.TestCase):
+
@parameterized.expand([FIND_SLICE_TEST_CASE_1, FIND_SLICE_TEST_CASE_2])
def test_correct_results(self, arguments, input_data, expected_result):
result = FindAllValidSlicesd(**arguments)(input_data)
@@ -344,6 +345,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestSpatialCropForegroundd(unittest.TestCase):
+
@parameterized.expand([CROP_TEST_CASE_1])
def test_correct_results(self, arguments, input_data, expected_result):
result = SpatialCropForegroundd(**arguments)(input_data)
@@ -368,6 +370,7 @@ def test_foreground_position(self, arguments, input_data, _):
class TestAddInitialSeedPointd(unittest.TestCase):
+
@parameterized.expand([ADD_INITIAL_POINT_TEST_CASE_1])
def test_correct_results(self, arguments, input_data, expected_result):
seed = 0
@@ -378,6 +381,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestAddGuidanceSignald(unittest.TestCase):
+
@parameterized.expand([ADD_GUIDANCE_TEST_CASE_1])
def test_correct_results(self, arguments, input_data, expected_result):
result = AddGuidanceSignald(**arguments)(input_data)
@@ -385,6 +389,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestFindDiscrepancyRegionsd(unittest.TestCase):
+
@parameterized.expand([FIND_DISCREPANCY_TEST_CASE_1])
def test_correct_results(self, arguments, input_data, expected_result):
result = FindDiscrepancyRegionsd(**arguments)(input_data)
@@ -392,6 +397,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestAddRandomGuidanced(unittest.TestCase):
+
@parameterized.expand([ADD_RANDOM_GUIDANCE_TEST_CASE_1])
def test_correct_results(self, arguments, input_data, expected_result):
seed = 0
@@ -402,6 +408,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestAddGuidanceFromPointsd(unittest.TestCase):
+
@parameterized.expand(
[
ADD_GUIDANCE_FROM_POINTS_TEST_CASE_1,
@@ -419,6 +426,7 @@ def test_correct_results(self, arguments, input_data, expected_pos, expected_neg
class TestSpatialCropGuidanced(unittest.TestCase):
+
@parameterized.expand(
[SPATIAL_CROP_GUIDANCE_TEST_CASE_1, SPATIAL_CROP_GUIDANCE_TEST_CASE_2, SPATIAL_CROP_GUIDANCE_TEST_CASE_3]
)
@@ -428,6 +436,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestResizeGuidanced(unittest.TestCase):
+
@parameterized.expand([RESIZE_GUIDANCE_TEST_CASE_1])
def test_correct_results(self, arguments, input_data, expected_result):
result = ResizeGuidanced(**arguments)(input_data)
@@ -435,6 +444,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestRestoreLabeld(unittest.TestCase):
+
@parameterized.expand([RESTORE_LABEL_TEST_CASE_1, RESTORE_LABEL_TEST_CASE_2])
def test_correct_results(self, arguments, input_data, expected_result):
result = RestoreLabeld(**arguments)(input_data)
@@ -442,6 +452,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestFetch2DSliced(unittest.TestCase):
+
@parameterized.expand([FETCH_2D_SLICE_TEST_CASE_1])
def test_correct_results(self, arguments, input_data, expected_result):
result = Fetch2DSliced(**arguments)(input_data)
diff --git a/tests/test_delete_itemsd.py b/tests/test_delete_itemsd.py
index 1ec77f29fd..c57184cd9f 100644
--- a/tests/test_delete_itemsd.py
+++ b/tests/test_delete_itemsd.py
@@ -28,6 +28,7 @@
class TestDeleteItemsd(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_memory(self, input_param, expected_key_size):
input_data = {"image": {}} if "sep" in input_param else {}
diff --git a/tests/test_denseblock.py b/tests/test_denseblock.py
index c14ca2ae7a..b741582422 100644
--- a/tests/test_denseblock.py
+++ b/tests/test_denseblock.py
@@ -20,6 +20,7 @@
class TestDenseBlock2D(TorchImageTestCase2D):
+
def test_block_empty(self):
block = DenseBlock([])
out = block(self.imt)
@@ -36,6 +37,7 @@ def test_block_conv(self):
class TestDenseBlock3D(TorchImageTestCase3D):
+
def test_block_conv(self):
conv1 = nn.Conv3d(self.input_channels, self.output_channels, 3, padding=1)
conv2 = nn.Conv3d(self.input_channels + self.output_channels, self.input_channels, 3, padding=1)
@@ -52,6 +54,7 @@ def test_block_conv(self):
class TestConvDenseBlock2D(TorchImageTestCase2D):
+
def test_block_empty(self):
conv = ConvDenseBlock(spatial_dims=2, in_channels=self.input_channels, channels=[])
out = conv(self.imt)
@@ -79,6 +82,7 @@ def test_block2(self):
class TestConvDenseBlock3D(TorchImageTestCase3D):
+
def test_block_empty(self):
conv = ConvDenseBlock(spatial_dims=3, in_channels=self.input_channels, channels=[])
out = conv(self.imt)
diff --git a/tests/test_densenet.py b/tests/test_densenet.py
index 1b44baf0c2..ee4be9003b 100644
--- a/tests/test_densenet.py
+++ b/tests/test_densenet.py
@@ -79,6 +79,7 @@
class TestPretrainedDENSENET(unittest.TestCase):
+
@parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2])
@skip_if_quick
def test_121_2d_shape_pretrain(self, model, input_param, input_shape, expected_shape):
@@ -103,6 +104,7 @@ def test_pretrain_consistency(self, model, input_param, input_shape):
class TestDENSENET(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_densenet_shape(self, model, input_param, input_shape, expected_shape):
net = model(**input_param).to(device)
diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py
index 5d511f3821..3171a67e2c 100644
--- a/tests/test_deprecated.py
+++ b/tests/test_deprecated.py
@@ -18,6 +18,7 @@
class TestDeprecatedRC(unittest.TestCase):
+
def setUp(self):
self.test_version_rc = "0.6.0rc1"
self.test_version = "0.6.0"
@@ -61,6 +62,7 @@ def foo3():
class TestDeprecated(unittest.TestCase):
+
def setUp(self):
self.test_version = "0.5.3+96.g1fa03c2.dirty"
self.prev_version = "0.4.3+96.g1fa03c2.dirty"
@@ -142,6 +144,7 @@ def test_meth_warning1(self):
"""Test deprecated decorator with just `since` set."""
class Foo5:
+
@deprecated(since=self.prev_version, version_val=self.test_version)
def meth1(self):
pass
@@ -152,6 +155,7 @@ def test_meth_except1(self):
"""Test deprecated decorator with just `since` set."""
class Foo6:
+
@deprecated(version_val=self.test_version)
def meth1(self):
pass
@@ -389,6 +393,7 @@ def test_deprecated_arg_default_errors(self):
# since > replaced
def since_grater_than_replaced():
+
@deprecated_arg_default(
"b",
old_default="a",
@@ -404,6 +409,7 @@ def foo(a, b=None):
# argname doesnt exist
def argname_doesnt_exist():
+
@deprecated_arg_default(
"other", old_default="a", new_default="b", since=self.test_version, version_val=self.test_version
)
@@ -414,6 +420,7 @@ def foo(a, b=None):
# argname has no default
def argname_has_no_default():
+
@deprecated_arg_default(
"a",
old_default="a",
@@ -429,6 +436,7 @@ def foo(a):
# new default is used but version < replaced
def argname_was_replaced_before_specified_version():
+
@deprecated_arg_default(
"a",
old_default="a",
diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py
index 105d3a4ace..e2efefeb77 100644
--- a/tests/test_detect_envelope.py
+++ b/tests/test_detect_envelope.py
@@ -116,6 +116,7 @@
@SkipIfNoModule("torch.fft")
class TestDetectEnvelope(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_1D_SINE,
@@ -151,6 +152,7 @@ def test_value_error(self, arguments, image, method):
@SkipIfModule("torch.fft")
class TestHilbertTransformNoFFTMod(unittest.TestCase):
+
def test_no_fft_module_error(self):
self.assertRaises(OptionalImportError, DetectEnvelope(), np.random.rand(1, 10))
diff --git a/tests/test_detection_coco_metrics.py b/tests/test_detection_coco_metrics.py
index 780031ee0c..a85eb37db7 100644
--- a/tests/test_detection_coco_metrics.py
+++ b/tests/test_detection_coco_metrics.py
@@ -23,6 +23,7 @@
class TestCOCOMetrics(unittest.TestCase):
+
def test_coco_run(self):
coco_metric = COCOMetric(classes=["c0", "c1", "c2"], iou_list=[0.1], max_detection=[10])
diff --git a/tests/test_detector_boxselector.py b/tests/test_detector_boxselector.py
index 8cc9b15911..326ecd5773 100644
--- a/tests/test_detector_boxselector.py
+++ b/tests/test_detector_boxselector.py
@@ -56,6 +56,7 @@
class TestBoxSelector(unittest.TestCase):
+
@parameterized.expand(TEST_CASE)
def test_box_selector(self, input_param, boxes, logits, image_shape, expected_results):
box_selector = BoxSelector(**input_param)
diff --git a/tests/test_detector_utils.py b/tests/test_detector_utils.py
index 41716934b5..352e1c2faf 100644
--- a/tests/test_detector_utils.py
+++ b/tests/test_detector_utils.py
@@ -79,6 +79,7 @@
class TestDetectorUtils(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_detector_utils(self, input_param, input_shape, expected_shape):
size_divisible = 32 * ensure_tuple(input_param["conv1_t_stride"])[0]
diff --git a/tests/test_dev_collate.py b/tests/test_dev_collate.py
index 97028f2597..44c4d2c598 100644
--- a/tests/test_dev_collate.py
+++ b/tests/test_dev_collate.py
@@ -36,6 +36,7 @@
class DevCollateTest(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_dev_collate(self, inputs, msg):
with self.assertLogs(level=logging.CRITICAL) as log:
diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py
index 58b9f4c191..225618ed2c 100644
--- a/tests/test_dice_ce_loss.py
+++ b/tests/test_dice_ce_loss.py
@@ -86,6 +86,7 @@
class TestDiceCELoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_result(self, input_param, input_data, expected_val):
diceceloss = DiceCELoss(**input_param)
diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py
index 845ef40cd5..13899da003 100644
--- a/tests/test_dice_focal_loss.py
+++ b/tests/test_dice_focal_loss.py
@@ -22,6 +22,7 @@
class TestDiceFocalLoss(unittest.TestCase):
+
def test_result_onehot_target_include_bg(self):
size = [3, 3, 5, 5]
label = torch.randint(low=0, high=2, size=size)
diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py
index 370d2dd5af..14aa6ec241 100644
--- a/tests/test_dice_loss.py
+++ b/tests/test_dice_loss.py
@@ -168,6 +168,7 @@
class TestDiceLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = DiceLoss(**input_param).forward(**input_data)
diff --git a/tests/test_diffusion_loss.py b/tests/test_diffusion_loss.py
index 05dfab95fb..93df77cc51 100644
--- a/tests/test_diffusion_loss.py
+++ b/tests/test_diffusion_loss.py
@@ -79,6 +79,7 @@
class TestDiffusionLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = DiffusionLoss(**input_param).forward(**input_data)
diff --git a/tests/test_dints_cell.py b/tests/test_dints_cell.py
index 21cef39d68..13990da373 100644
--- a/tests/test_dints_cell.py
+++ b/tests/test_dints_cell.py
@@ -98,6 +98,7 @@
class TestCell(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_2D + TEST_CASES_3D)
def test_cell_3d(self, input_param, ops, weight, input_shape, expected_shape):
net = Cell(**input_param)
diff --git a/tests/test_dints_mixop.py b/tests/test_dints_mixop.py
index 09d2e7a423..683a8d1005 100644
--- a/tests/test_dints_mixop.py
+++ b/tests/test_dints_mixop.py
@@ -61,6 +61,7 @@
class TestMixOP(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_3D)
def test_mixop_3d(self, input_param, ops, weight, input_shape, expected_shape):
net = MixedOp(ops=Cell.OPS3D, **input_param)
diff --git a/tests/test_dints_network.py b/tests/test_dints_network.py
index 532c31886b..5ee4db7a4e 100644
--- a/tests/test_dints_network.py
+++ b/tests/test_dints_network.py
@@ -115,6 +115,7 @@
@skip_if_quick
class TestDints(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_3D + TEST_CASES_2D)
def test_dints_inference(self, dints_grid_params, dints_params, input_shape, expected_shape):
grid = TopologySearch(**dints_grid_params)
@@ -155,6 +156,7 @@ def test_dints_search(self, dints_grid_params, dints_params, input_shape, expect
@SkipIfBeforePyTorchVersion((1, 9))
class TestDintsTS(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_3D + TEST_CASES_2D)
def test_script(self, dints_grid_params, dints_params, input_shape, _):
grid = TopologyInstance(**dints_grid_params)
diff --git a/tests/test_discriminator.py b/tests/test_discriminator.py
index 62635e286e..f615605e56 100644
--- a/tests/test_discriminator.py
+++ b/tests/test_discriminator.py
@@ -42,6 +42,7 @@
class TestDiscriminator(unittest.TestCase):
+
@parameterized.expand(CASES)
def test_shape(self, input_param, input_data, expected_shape):
net = Discriminator(**input_param)
diff --git a/tests/test_distance_transform_edt.py b/tests/test_distance_transform_edt.py
index 83b9348348..cf5c253c0c 100644
--- a/tests/test_distance_transform_edt.py
+++ b/tests/test_distance_transform_edt.py
@@ -146,6 +146,7 @@
class TestDistanceTransformEDT(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_scipy_transform(self, input, expected_output):
transform = DistanceTransformEDT()
diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py
index 696bcfc78f..555f7dc250 100644
--- a/tests/test_download_and_extract.py
+++ b/tests/test_download_and_extract.py
@@ -24,6 +24,7 @@
class TestDownloadAndExtract(unittest.TestCase):
+
@skip_if_quick
def test_actions(self):
testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
diff --git a/tests/test_download_url_yandex.py b/tests/test_download_url_yandex.py
index a08105a93f..54d39b06ff 100644
--- a/tests/test_download_url_yandex.py
+++ b/tests/test_download_url_yandex.py
@@ -29,6 +29,7 @@
class TestDownloadUrlYandex(unittest.TestCase):
+
@unittest.skip("data source unstable")
def test_verify(self):
with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py
index cd40be4306..34afa248ad 100644
--- a/tests/test_downsample_block.py
+++ b/tests/test_downsample_block.py
@@ -37,6 +37,7 @@
class TestMaxAvgPool(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_shape, expected_shape):
net = MaxAvgPool(**input_param)
diff --git a/tests/test_drop_path.py b/tests/test_drop_path.py
index ab2150e548..1b9974791a 100644
--- a/tests/test_drop_path.py
+++ b/tests/test_drop_path.py
@@ -28,6 +28,7 @@
class TestDropPath(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_shape):
im = torch.rand(input_shape)
diff --git a/tests/test_ds_loss.py b/tests/test_ds_loss.py
index de7aec1ced..daa4ed1e5e 100644
--- a/tests/test_ds_loss.py
+++ b/tests/test_ds_loss.py
@@ -135,6 +135,7 @@
class TestDSLossDiceCE(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_DICECE)
def test_result(self, input_param, input_param2, input_data, expected_val):
diceceloss = DeepSupervisionLoss(DiceCELoss(**input_param), **input_param2)
@@ -160,6 +161,7 @@ def test_script(self):
@SkipIfBeforePyTorchVersion((1, 11))
class TestDSLossDiceCE2(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_DICECE2)
def test_result(self, input_param, input_param2, input_data, expected_val):
diceceloss = DeepSupervisionLoss(DiceCELoss(**input_param), **input_param2)
@@ -169,6 +171,7 @@ def test_result(self, input_param, input_param2, input_data, expected_val):
@SkipIfBeforePyTorchVersion((1, 11))
class TestDSLossDice(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_DICE)
def test_result(self, input_param, input_data, expected_val):
loss = DeepSupervisionLoss(DiceLoss(**input_param))
@@ -178,6 +181,7 @@ def test_result(self, input_param, input_data, expected_val):
@SkipIfBeforePyTorchVersion((1, 11))
class TestDSLossDiceFocal(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_DICEFOCAL)
def test_result(self, input_param, input_data, expected_val):
loss = DeepSupervisionLoss(DiceFocalLoss(**input_param))
diff --git a/tests/test_dvf2ddf.py b/tests/test_dvf2ddf.py
index f18b5b7297..b385b897e5 100644
--- a/tests/test_dvf2ddf.py
+++ b/tests/test_dvf2ddf.py
@@ -42,6 +42,7 @@
class TestDVF2DDF(unittest.TestCase):
+
def setUp(self):
set_determinism(0)
diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py
index 247da14b7d..b0137ae245 100644
--- a/tests/test_dynunet.py
+++ b/tests/test_dynunet.py
@@ -109,6 +109,7 @@
class TestDynUNet(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_DYNUNET_3D)
def test_shape(self, input_param, input_shape, expected_shape):
net = DynUNet(**input_param).to(device)
@@ -128,6 +129,7 @@ def test_script(self):
@skip_if_no_cuda
@skip_if_windows
class TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase):
+
def setUp(self):
try:
layer = InstanceNorm3dNVFuser(num_features=1, affine=False).to("cuda:0")
@@ -161,6 +163,7 @@ def test_consistency(self, input_param, input_shape, _):
class TestDynUNetDeepSupervision(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_DEEP_SUPERVISION)
def test_shape(self, input_param, input_shape, expected_shape):
net = DynUNet(**input_param).to(device)
diff --git a/tests/test_dynunet_block.py b/tests/test_dynunet_block.py
index b34ccb31ba..4d9e06670b 100644
--- a/tests/test_dynunet_block.py
+++ b/tests/test_dynunet_block.py
@@ -73,6 +73,7 @@
class TestResBasicBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_RES_BASIC_BLOCK)
def test_shape(self, input_param, input_shape, expected_shape):
for net in [UnetResBlock(**input_param), UnetBasicBlock(**input_param)]:
@@ -96,6 +97,7 @@ def test_script(self):
class TestUpBlock(unittest.TestCase):
+
@parameterized.expand(TEST_UP_BLOCK)
def test_shape(self, input_param, input_shape, expected_shape, skip_shape):
net = UnetUpBlock(**input_param)
diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py
index 5bdad5a568..c16526eaa3 100644
--- a/tests/test_efficientnet.py
+++ b/tests/test_efficientnet.py
@@ -248,6 +248,7 @@ def make_shape_cases(
class TestEFFICIENTNET(unittest.TestCase):
+
@parameterized.expand(CASES_1D + CASES_2D + CASES_3D + CASES_VARIATIONS)
def test_shape(self, input_param, input_shape, expected_shape):
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -376,6 +377,7 @@ def test_script(self):
class TestExtractFeatures(unittest.TestCase):
+
@parameterized.expand(CASE_EXTRACT_FEATURES)
def test_shape(self, input_param, input_shape, expected_shapes):
device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/tests/test_ensemble_evaluator.py b/tests/test_ensemble_evaluator.py
index 40a4a72dd5..ad81d35d52 100644
--- a/tests/test_ensemble_evaluator.py
+++ b/tests/test_ensemble_evaluator.py
@@ -26,11 +26,13 @@
class TestEnsembleEvaluator(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_content(self, pred_keys):
device = torch.device("cpu:0")
class TestDataset(torch.utils.data.Dataset):
+
def __len__(self):
return 8
@@ -40,6 +42,7 @@ def __getitem__(self, index):
val_loader = torch.utils.data.DataLoader(TestDataset())
class TestNet(torch.nn.Module):
+
def __init__(self, func):
super().__init__()
self.func = func
diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py
index 027b18b7dd..0c9ad5869e 100644
--- a/tests/test_ensure_channel_first.py
+++ b/tests/test_ensure_channel_first.py
@@ -46,6 +46,7 @@
class TestEnsureChannelFirst(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
@unittest.skipUnless(has_itk, "itk not installed")
def test_load_nifti(self, input_param, filenames, original_channel_dim):
diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py
index 08e2709641..63a437894b 100644
--- a/tests/test_ensure_channel_firstd.py
+++ b/tests/test_ensure_channel_firstd.py
@@ -32,6 +32,7 @@
class TestEnsureChannelFirstd(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_load_nifti(self, input_param, filenames, original_channel_dim):
if original_channel_dim is None:
diff --git a/tests/test_ensure_tuple.py b/tests/test_ensure_tuple.py
index dc6649ec4c..ec8c92785a 100644
--- a/tests/test_ensure_tuple.py
+++ b/tests/test_ensure_tuple.py
@@ -37,6 +37,7 @@
class TestEnsureTuple(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, input, expected_value, wrap_array=False):
result = ensure_tuple(input, wrap_array)
diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py
index 7d6b7ca586..00b01898b3 100644
--- a/tests/test_ensure_type.py
+++ b/tests/test_ensure_type.py
@@ -22,6 +22,7 @@
class TestEnsureType(unittest.TestCase):
+
def test_array_input(self):
test_datas = [np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])]
if torch.cuda.is_available():
diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py
index 4fa942e742..09aa1f04b5 100644
--- a/tests/test_ensure_typed.py
+++ b/tests/test_ensure_typed.py
@@ -22,6 +22,7 @@
class TestEnsureTyped(unittest.TestCase):
+
def test_array_input(self):
test_datas = [np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])]
if torch.cuda.is_available():
diff --git a/tests/test_enum_bound_interp.py b/tests/test_enum_bound_interp.py
index 5a63fc05af..cd3119f91c 100644
--- a/tests/test_enum_bound_interp.py
+++ b/tests/test_enum_bound_interp.py
@@ -22,6 +22,7 @@
@skip_if_no_cpp_extension
class TestEnumBoundInterp(unittest.TestCase):
+
def test_bound(self):
self.assertEqual(str(b.replicate), "BoundType.replicate")
self.assertEqual(str(b.nearest), "BoundType.replicate")
diff --git a/tests/test_eval_mode.py b/tests/test_eval_mode.py
index 8458753e1f..b40bb78327 100644
--- a/tests/test_eval_mode.py
+++ b/tests/test_eval_mode.py
@@ -19,6 +19,7 @@
class TestEvalMode(unittest.TestCase):
+
def test_eval_mode(self):
t = torch.rand(1, 1, 4, 4)
p = torch.nn.Conv2d(1, 1, 3)
diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py
index f338944daa..d6d26c7e23 100644
--- a/tests/test_evenly_divisible_all_gather_dist.py
+++ b/tests/test_evenly_divisible_all_gather_dist.py
@@ -21,6 +21,7 @@
class DistributedEvenlyDivisibleAllGather(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2)
def test_data(self):
self._run()
diff --git a/tests/test_factorized_increase.py b/tests/test_factorized_increase.py
index f7642ff357..b082c70090 100644
--- a/tests/test_factorized_increase.py
+++ b/tests/test_factorized_increase.py
@@ -25,6 +25,7 @@
class TestFactInc(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_3D)
def test_factorized_increase_3d(self, input_param, input_shape, expected_shape):
net = FactorizedIncreaseBlock(**input_param)
diff --git a/tests/test_factorized_reduce.py b/tests/test_factorized_reduce.py
index 224a0cb351..5e879c3cb5 100644
--- a/tests/test_factorized_reduce.py
+++ b/tests/test_factorized_reduce.py
@@ -25,6 +25,7 @@
class TestFactRed(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_3D)
def test_factorized_reduce_3d(self, input_param, input_shape, expected_shape):
net = FactorizedReduceBlock(**input_param)
diff --git a/tests/test_fastmri_reader.py b/tests/test_fastmri_reader.py
index b15bd4b6a2..af2eed7db5 100644
--- a/tests/test_fastmri_reader.py
+++ b/tests/test_fastmri_reader.py
@@ -65,6 +65,7 @@
class TestMRIUtils(unittest.TestCase):
+
@parameterized.expand([TEST_CASE1, TEST_CASE2])
def test_get_data(self, test_data, test_res, test_meta):
reader = FastMRIReader()
diff --git a/tests/test_fft_utils.py b/tests/test_fft_utils.py
index 971df2b411..7c7035770a 100644
--- a/tests/test_fft_utils.py
+++ b/tests/test_fft_utils.py
@@ -42,6 +42,7 @@
class TestFFT(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test(self, test_data, res_data):
result = fftn_centered(test_data, spatial_dims=2, is_complex=False)
diff --git a/tests/test_fg_bg_to_indices.py b/tests/test_fg_bg_to_indices.py
index 7d88bb7ee9..a28c491333 100644
--- a/tests/test_fg_bg_to_indices.py
+++ b/tests/test_fg_bg_to_indices.py
@@ -72,6 +72,7 @@
class TestFgBgToIndices(unittest.TestCase):
+
@parameterized.expand(TESTS_CASES)
def test_type_shape(self, input_data, label, image, expected_fg, expected_bg):
fg_indices, bg_indices = FgBgToIndices(**input_data)(label, image)
diff --git a/tests/test_fg_bg_to_indicesd.py b/tests/test_fg_bg_to_indicesd.py
index d0d1ae5fb6..c6dd2059f4 100644
--- a/tests/test_fg_bg_to_indicesd.py
+++ b/tests/test_fg_bg_to_indicesd.py
@@ -67,6 +67,7 @@
class TestFgBgToIndicesd(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_type_shape(self, input_data, data, expected_fg, expected_bg):
result = FgBgToIndicesd(**input_data)(data)
diff --git a/tests/test_file_basename.py b/tests/test_file_basename.py
index 93e2027575..27e2d98c7d 100644
--- a/tests/test_file_basename.py
+++ b/tests/test_file_basename.py
@@ -20,6 +20,7 @@
class TestFilename(unittest.TestCase):
+
def test_value(self):
with tempfile.TemporaryDirectory() as tempdir:
output_tmp = os.path.join(tempdir, "output")
diff --git a/tests/test_fill_holes.py b/tests/test_fill_holes.py
index 65c59d49eb..241f7f8254 100644
--- a/tests/test_fill_holes.py
+++ b/tests/test_fill_holes.py
@@ -195,6 +195,7 @@
class TestFillHoles(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_results(self, _, args, input_image, expected):
converter = FillHoles(**args)
diff --git a/tests/test_fill_holesd.py b/tests/test_fill_holesd.py
index 3f98dab1bf..28c17b00ac 100644
--- a/tests/test_fill_holesd.py
+++ b/tests/test_fill_holesd.py
@@ -196,6 +196,7 @@
class TestFillHoles(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_results(self, _, args, input_image, expected):
key = CommonKeys.IMAGE
diff --git a/tests/test_fl_exchange_object.py b/tests/test_fl_exchange_object.py
index 293f9d518b..dab4eae037 100644
--- a/tests/test_fl_exchange_object.py
+++ b/tests/test_fl_exchange_object.py
@@ -46,6 +46,7 @@
@SkipIfNoModule("torchvision")
@SkipIfNoModule("ignite")
class TestFLExchangeObject(unittest.TestCase):
+
@parameterized.expand([TEST_INIT_1, TEST_INIT_2])
def test_init(self, input_params, expected_str):
eo = ExchangeObject(**input_params)
diff --git a/tests/test_fl_monai_algo.py b/tests/test_fl_monai_algo.py
index ca781ff166..54bec24b98 100644
--- a/tests/test_fl_monai_algo.py
+++ b/tests/test_fl_monai_algo.py
@@ -181,6 +181,7 @@
@SkipIfNoModule("ignite")
@SkipIfNoModule("mlflow")
class TestFLMonaiAlgo(unittest.TestCase):
+
@parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3, TEST_TRAIN_4])
def test_train(self, input_params):
# initialize algo
diff --git a/tests/test_fl_monai_algo_dist.py b/tests/test_fl_monai_algo_dist.py
index 1302ab6618..d8dbfa5339 100644
--- a/tests/test_fl_monai_algo_dist.py
+++ b/tests/test_fl_monai_algo_dist.py
@@ -32,6 +32,7 @@
@SkipIfNoModule("ignite")
@SkipIfBeforePyTorchVersion((1, 11, 1))
class TestFLMonaiAlgo(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2, init_method="no_init")
@skip_if_no_cuda
def test_train(self):
diff --git a/tests/test_fl_monai_algo_stats.py b/tests/test_fl_monai_algo_stats.py
index 307b3f539c..6e58f8af88 100644
--- a/tests/test_fl_monai_algo_stats.py
+++ b/tests/test_fl_monai_algo_stats.py
@@ -64,6 +64,7 @@
@SkipIfNoModule("ignite")
class TestFLMonaiAlgo(unittest.TestCase):
+
@parameterized.expand([TEST_GET_DATA_STATS_1, TEST_GET_DATA_STATS_2, TEST_GET_DATA_STATS_3])
def test_get_data_stats(self, input_params):
# initialize algo
diff --git a/tests/test_flatten_sub_keysd.py b/tests/test_flatten_sub_keysd.py
index 997f203870..1a642e5fc4 100644
--- a/tests/test_flatten_sub_keysd.py
+++ b/tests/test_flatten_sub_keysd.py
@@ -46,6 +46,7 @@
class TestFlattenSubKeysd(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_dict(self, params, input_data, expected):
result = FlattenSubKeysd(**params)(input_data)
diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py
index 1d831f0976..404855c9a8 100644
--- a/tests/test_flexible_unet.py
+++ b/tests/test_flexible_unet.py
@@ -35,6 +35,7 @@
class DummyEncoder(BaseEncoder):
+
@classmethod
def get_encoder_parameters(cls):
basic_dict = {"spatial_dims": 2, "in_channels": 3, "pretrained": False}
@@ -364,6 +365,7 @@ def make_error_case():
@skip_if_quick
class TestFLEXIBLEUNET(unittest.TestCase):
+
@parameterized.expand(CASES_2D + CASES_3D + CASES_VARIATIONS)
def test_shape(self, input_param, input_shape, expected_shape):
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -404,6 +406,7 @@ def test_error_raise(self, input_param):
class TestFlexUNetEncoderRegister(unittest.TestCase):
+
@parameterized.expand(CASE_REGISTER_ENCODER)
def test_regist(self, encoder):
tmp_backbone = FlexUNetEncoderRegister()
diff --git a/tests/test_flip.py b/tests/test_flip.py
index d7df55fde0..789ec86920 100644
--- a/tests/test_flip.py
+++ b/tests/test_flip.py
@@ -34,6 +34,7 @@
class TestFlip(NumpyImageTestCase2D):
+
@parameterized.expand(INVALID_CASES)
def test_invalid_inputs(self, _, spatial_axis, raises):
with self.assertRaises(raises):
diff --git a/tests/test_flipd.py b/tests/test_flipd.py
index 19f9ed0882..277f387051 100644
--- a/tests/test_flipd.py
+++ b/tests/test_flipd.py
@@ -35,6 +35,7 @@
class TestFlipd(NumpyImageTestCase2D):
+
@parameterized.expand(INVALID_CASES)
def test_invalid_cases(self, _, spatial_axis, raises):
with self.assertRaises(raises):
diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py
index 57df6a3460..de8d625058 100644
--- a/tests/test_focal_loss.py
+++ b/tests/test_focal_loss.py
@@ -79,6 +79,7 @@
class TestFocalLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_result(self, input_param, input_data, expected_val):
focal_loss = FocalLoss(**input_param)
diff --git a/tests/test_folder_layout.py b/tests/test_folder_layout.py
index d6d4bdf679..6f72eee51f 100644
--- a/tests/test_folder_layout.py
+++ b/tests/test_folder_layout.py
@@ -60,6 +60,7 @@
class TestFolderLayout(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_value(self, con_params, f_params, expected):
fname = FolderLayout(**con_params).filename(**f_params)
diff --git a/tests/test_foreground_mask.py b/tests/test_foreground_mask.py
index eb59ae2db6..1aa54f4d3a 100644
--- a/tests/test_foreground_mask.py
+++ b/tests/test_foreground_mask.py
@@ -81,6 +81,7 @@
@unittest.skipUnless(has_skimage, "Requires sci-kit image")
class TestForegroundMask(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_foreground_mask(self, in_type, arguments, image, mask):
input_image = in_type(image)
diff --git a/tests/test_foreground_maskd.py b/tests/test_foreground_maskd.py
index 24cb233c30..dc7b6cfb24 100644
--- a/tests/test_foreground_maskd.py
+++ b/tests/test_foreground_maskd.py
@@ -89,6 +89,7 @@
@unittest.skipUnless(has_skimage, "Requires sci-kit image")
class TestForegroundMaskd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_foreground_mask(self, in_type, arguments, data_dict, mask):
data_dict[arguments["keys"]] = in_type(data_dict[arguments["keys"]])
diff --git a/tests/test_fourier.py b/tests/test_fourier.py
index 3613db989f..177fc280f7 100644
--- a/tests/test_fourier.py
+++ b/tests/test_fourier.py
@@ -28,6 +28,7 @@
@SkipIfBeforePyTorchVersion((1, 8))
@SkipIfNoModule("torch.fft")
class TestFourier(unittest.TestCase):
+
def setUp(self):
set_determinism(0)
super().setUp()
diff --git a/tests/test_fpn_block.py b/tests/test_fpn_block.py
index c6121c5b98..969800e80a 100644
--- a/tests/test_fpn_block.py
+++ b/tests/test_fpn_block.py
@@ -44,6 +44,7 @@
class TestFPNBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_fpn_block(self, input_param, input_shape, expected_shape):
net = FeaturePyramidNetwork(**input_param)
@@ -67,6 +68,7 @@ def test_script(self, input_param, input_shape, expected_shape):
@unittest.skipUnless(has_torchvision, "Requires torchvision")
class TestFPN(unittest.TestCase):
+
@parameterized.expand(TEST_CASES2)
def test_fpn(self, input_param, input_shape, expected_shape):
net = _resnet_fpn_extractor(backbone=resnet50(), spatial_dims=input_param["spatial_dims"], returned_layers=[1])
diff --git a/tests/test_freeze_layers.py b/tests/test_freeze_layers.py
index 29594ed98a..1bea4ed1b5 100644
--- a/tests/test_freeze_layers.py
+++ b/tests/test_freeze_layers.py
@@ -27,6 +27,7 @@
class TestModuleState(unittest.TestCase):
+
def tearDown(self):
set_determinism(None)
diff --git a/tests/test_from_engine_hovernet.py b/tests/test_from_engine_hovernet.py
index 227fa66baa..7d1a784466 100644
--- a/tests/test_from_engine_hovernet.py
+++ b/tests/test_from_engine_hovernet.py
@@ -28,6 +28,7 @@
class TestFromEngineHovernet(unittest.TestCase):
+
@parameterized.expand(CASES)
def test_results(self, input, expected):
output = from_engine_hovernet(keys=["A", "B"], nested_key="C")(input)
diff --git a/tests/test_fullyconnectednet.py b/tests/test_fullyconnectednet.py
index 94fc4caa6e..863d1399a9 100644
--- a/tests/test_fullyconnectednet.py
+++ b/tests/test_fullyconnectednet.py
@@ -42,6 +42,7 @@
class TestFullyConnectedNet(unittest.TestCase):
+
def setUp(self):
self.batch_size = 10
self.inSize = 10
diff --git a/tests/test_gaussian.py b/tests/test_gaussian.py
index b98507b793..689d8088f9 100644
--- a/tests/test_gaussian.py
+++ b/tests/test_gaussian.py
@@ -224,6 +224,7 @@
class TestGaussian1d(unittest.TestCase):
+
def test_gaussian(self):
np.testing.assert_allclose(
gaussian_1d(0.5, 8),
diff --git a/tests/test_gaussian_filter.py b/tests/test_gaussian_filter.py
index 1beee579e8..4ab689c565 100644
--- a/tests/test_gaussian_filter.py
+++ b/tests/test_gaussian_filter.py
@@ -35,6 +35,7 @@
class TestGaussianFilterBackprop(unittest.TestCase):
+
def code_to_run(self, input_args):
input_dims = input_args.get("dims", (2, 3, 8))
device = (
@@ -94,6 +95,7 @@ def test_train_slow(self, input_args):
class GaussianFilterTestCase(unittest.TestCase):
+
def test_1d(self):
a = torch.ones(1, 8, 10)
g = GaussianFilter(1, 3, 3).to(torch.device("cpu:0"))
diff --git a/tests/test_gaussian_sharpen.py b/tests/test_gaussian_sharpen.py
index 2509a4fc26..392a7b376b 100644
--- a/tests/test_gaussian_sharpen.py
+++ b/tests/test_gaussian_sharpen.py
@@ -82,6 +82,7 @@
class TestGaussianSharpen(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
result = GaussianSharpen(**arguments)(image)
diff --git a/tests/test_gaussian_sharpend.py b/tests/test_gaussian_sharpend.py
index 75ea915d96..15b219fd2c 100644
--- a/tests/test_gaussian_sharpend.py
+++ b/tests/test_gaussian_sharpend.py
@@ -82,6 +82,7 @@
class TestGaussianSharpend(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
result = GaussianSharpend(**arguments)(image)
diff --git a/tests/test_gaussian_smooth.py b/tests/test_gaussian_smooth.py
index 38b29bbd17..9f99ebe0f8 100644
--- a/tests/test_gaussian_smooth.py
+++ b/tests/test_gaussian_smooth.py
@@ -86,6 +86,7 @@
class TestGaussianSmooth(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
result = GaussianSmooth(**arguments)(image)
diff --git a/tests/test_gaussian_smoothd.py b/tests/test_gaussian_smoothd.py
index 8702c073c8..a6de4a159b 100644
--- a/tests/test_gaussian_smoothd.py
+++ b/tests/test_gaussian_smoothd.py
@@ -86,6 +86,7 @@
class TestGaussianSmoothd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
result = GaussianSmoothd(**arguments)(image)
diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py
index 29f2d0096b..f0a419dcf5 100644
--- a/tests/test_gdsdataset.py
+++ b/tests/test_gdsdataset.py
@@ -64,6 +64,7 @@
class _InplaceXform(Transform):
+
def __call__(self, data):
data[0] = data[0] + 1
return data
@@ -73,6 +74,7 @@ def __call__(self, data):
@unittest.skipUnless(has_nib, "Requires nibabel package.")
@unittest.skipUnless(has_kvikio_numpy, "Requires scikit-image library.")
class TestDataset(unittest.TestCase):
+
def test_cache(self):
"""testing no inplace change to the hashed item"""
for p in TEST_NDARRAYS[:2]:
diff --git a/tests/test_generalized_dice_focal_loss.py b/tests/test_generalized_dice_focal_loss.py
index 33f6653212..8a0a80865e 100644
--- a/tests/test_generalized_dice_focal_loss.py
+++ b/tests/test_generalized_dice_focal_loss.py
@@ -21,6 +21,7 @@
class TestGeneralizedDiceFocalLoss(unittest.TestCase):
+
def test_result_onehot_target_include_bg(self):
size = [3, 3, 5, 5]
label = torch.randint(low=0, high=2, size=size)
diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py
index d8ba496d03..7499507129 100644
--- a/tests/test_generalized_dice_loss.py
+++ b/tests/test_generalized_dice_loss.py
@@ -142,6 +142,7 @@
class TestGeneralizedDiceLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = GeneralizedDiceLoss(**input_param).forward(**input_data)
diff --git a/tests/test_generalized_wasserstein_dice_loss.py b/tests/test_generalized_wasserstein_dice_loss.py
index 7b85fdc5b6..6b9d57e831 100644
--- a/tests/test_generalized_wasserstein_dice_loss.py
+++ b/tests/test_generalized_wasserstein_dice_loss.py
@@ -24,6 +24,7 @@
class TestGeneralizedWassersteinDiceLoss(unittest.TestCase):
+
def test_bin_seg_2d(self):
target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
@@ -160,6 +161,7 @@ def test_convergence(self):
# define a model with one layer
class OnelayerNet(nn.Module):
+
def __init__(self):
super().__init__()
self.layer = nn.Linear(num_voxels, num_voxels * num_classes)
diff --git a/tests/test_generate_distance_map.py b/tests/test_generate_distance_map.py
index 724a335e1a..42f5664647 100644
--- a/tests/test_generate_distance_map.py
+++ b/tests/test_generate_distance_map.py
@@ -36,6 +36,7 @@
class TestGenerateDistanceMap(unittest.TestCase):
+
@parameterized.expand(EXCEPTION_TESTS)
def test_value(self, arguments, mask, probmap, exception_type):
with self.assertRaises(exception_type):
diff --git a/tests/test_generate_distance_mapd.py b/tests/test_generate_distance_mapd.py
index 17c5aa782b..2bddadf5b8 100644
--- a/tests/test_generate_distance_mapd.py
+++ b/tests/test_generate_distance_mapd.py
@@ -55,6 +55,7 @@
class TestGenerateDistanceMapd(unittest.TestCase):
+
@parameterized.expand(EXCEPTION_TESTS)
def test_value(self, arguments, mask, border_map, exception_type):
with self.assertRaises(exception_type):
diff --git a/tests/test_generate_instance_border.py b/tests/test_generate_instance_border.py
index 8634bb7d77..fc1035dfe5 100644
--- a/tests/test_generate_instance_border.py
+++ b/tests/test_generate_instance_border.py
@@ -34,6 +34,7 @@
class TestGenerateInstanceBorder(unittest.TestCase):
+
@parameterized.expand(EXCEPTION_TESTS)
def test_value(self, arguments, mask, hover_map, exception_type):
with self.assertRaises(exception_type):
diff --git a/tests/test_generate_instance_borderd.py b/tests/test_generate_instance_borderd.py
index fc81e8f87c..cdfbee4193 100644
--- a/tests/test_generate_instance_borderd.py
+++ b/tests/test_generate_instance_borderd.py
@@ -44,6 +44,7 @@
class TestGenerateInstanceBorderd(unittest.TestCase):
+
@parameterized.expand(EXCEPTION_TESTS)
def test_value(self, arguments, mask, hover_map, exception_type):
with self.assertRaises(exception_type):
diff --git a/tests/test_generate_instance_centroid.py b/tests/test_generate_instance_centroid.py
index f9fdc602a9..6b4d533401 100644
--- a/tests/test_generate_instance_centroid.py
+++ b/tests/test_generate_instance_centroid.py
@@ -41,6 +41,7 @@
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
class TestGenerateInstanceCentroid(unittest.TestCase):
+
@parameterized.expand(TEST_CASE)
def test_shape(self, in_type, test_data, offset, expected):
inst_bbox = get_bbox(test_data[None])
diff --git a/tests/test_generate_instance_centroidd.py b/tests/test_generate_instance_centroidd.py
index 92e45cdf84..d381ad8c0e 100644
--- a/tests/test_generate_instance_centroidd.py
+++ b/tests/test_generate_instance_centroidd.py
@@ -41,6 +41,7 @@
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
class TestGenerateInstanceCentroidd(unittest.TestCase):
+
@parameterized.expand(TEST_CASE)
def test_shape(self, in_type, test_data, offset, expected):
inst_bbox = get_bbox(test_data[None])
diff --git a/tests/test_generate_instance_contour.py b/tests/test_generate_instance_contour.py
index 9058855e62..7f4290747d 100644
--- a/tests/test_generate_instance_contour.py
+++ b/tests/test_generate_instance_contour.py
@@ -46,6 +46,7 @@
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
class TestGenerateInstanceContour(unittest.TestCase):
+
@parameterized.expand(TEST_CASE)
def test_shape(self, in_type, test_data, min_num_points, offset, expected):
inst_bbox = get_bbox(test_data[None])
diff --git a/tests/test_generate_instance_contourd.py b/tests/test_generate_instance_contourd.py
index 22e3669850..5c831ee680 100644
--- a/tests/test_generate_instance_contourd.py
+++ b/tests/test_generate_instance_contourd.py
@@ -46,6 +46,7 @@
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
class TestGenerateInstanceContourd(unittest.TestCase):
+
@parameterized.expand(TEST_CASE)
def test_shape(self, in_type, test_data, min_num_points, offset, expected):
inst_bbox = get_bbox(test_data[None])
diff --git a/tests/test_generate_instance_type.py b/tests/test_generate_instance_type.py
index 354f8640ae..24e1d1b6d0 100644
--- a/tests/test_generate_instance_type.py
+++ b/tests/test_generate_instance_type.py
@@ -41,6 +41,7 @@
class TestGenerateInstanceType(unittest.TestCase):
+
@parameterized.expand(TEST_CASE)
def test_shape(self, in_type, type_pred, seg_pred, bbox, expected):
result = GenerateInstanceType()(in_type(type_pred[None]), in_type(seg_pred[None]), bbox, 1)
diff --git a/tests/test_generate_instance_typed.py b/tests/test_generate_instance_typed.py
index 84a5344503..958f68d6bb 100644
--- a/tests/test_generate_instance_typed.py
+++ b/tests/test_generate_instance_typed.py
@@ -41,6 +41,7 @@
class TestGenerateInstanceTyped(unittest.TestCase):
+
@parameterized.expand(TEST_CASE)
def test_shape(self, in_type, type_pred, seg_pred, bbox, expected):
test_data = {"type_pred": in_type(type_pred[None]), "seg": in_type(seg_pred[None]), "bbox": bbox, "id": 1}
diff --git a/tests/test_generate_label_classes_crop_centers.py b/tests/test_generate_label_classes_crop_centers.py
index c276171bd5..1cbb5f05c3 100644
--- a/tests/test_generate_label_classes_crop_centers.py
+++ b/tests/test_generate_label_classes_crop_centers.py
@@ -48,6 +48,7 @@
class TestGenerateLabelClassesCropCenters(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_type_shape(self, input_data, expected_type, expected_count, expected_shape):
results = []
diff --git a/tests/test_generate_param_groups.py b/tests/test_generate_param_groups.py
index 8301e40188..a78dba9f03 100644
--- a/tests/test_generate_param_groups.py
+++ b/tests/test_generate_param_groups.py
@@ -68,6 +68,7 @@
class TestGenerateParamGroups(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
def test_lr_values(self, input_param, expected_values, expected_groups):
device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py
index 13b7b728b4..de127b33df 100644
--- a/tests/test_generate_pos_neg_label_crop_centers.py
+++ b/tests/test_generate_pos_neg_label_crop_centers.py
@@ -51,6 +51,7 @@
class TestGeneratePosNegLabelCropCenters(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_type_shape(self, input_data, expected_type, expected_count, expected_shape):
results = []
diff --git a/tests/test_generate_spatial_bounding_box.py b/tests/test_generate_spatial_bounding_box.py
index a67e7d0175..6d5b415ec2 100644
--- a/tests/test_generate_spatial_bounding_box.py
+++ b/tests/test_generate_spatial_bounding_box.py
@@ -104,6 +104,7 @@
class TestGenerateSpatialBoundingBox(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, input_data, expected_box):
result = generate_spatial_bounding_box(**input_data)
diff --git a/tests/test_generate_succinct_contour.py b/tests/test_generate_succinct_contour.py
index 1c60e99546..fc4f5660d9 100644
--- a/tests/test_generate_succinct_contour.py
+++ b/tests/test_generate_succinct_contour.py
@@ -44,6 +44,7 @@
class TestGenerateSuccinctContour(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape(self, test_data, height, width, expected):
result = GenerateSuccinctContour(height=height, width=width)(test_data)
diff --git a/tests/test_generate_succinct_contourd.py b/tests/test_generate_succinct_contourd.py
index e94a02fed5..7b023d8618 100644
--- a/tests/test_generate_succinct_contourd.py
+++ b/tests/test_generate_succinct_contourd.py
@@ -45,6 +45,7 @@
class TestGenerateSuccinctContour(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape(self, data, height, width, expected):
test_data = {"contour": data}
diff --git a/tests/test_generate_watershed_markers.py b/tests/test_generate_watershed_markers.py
index a763361913..238fb00ee0 100644
--- a/tests/test_generate_watershed_markers.py
+++ b/tests/test_generate_watershed_markers.py
@@ -38,6 +38,7 @@
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
@unittest.skipUnless(has_scipy, "Requires scipy library.")
class TestGenerateWatershedMarkers(unittest.TestCase):
+
@parameterized.expand(EXCEPTION_TESTS)
def test_value(self, arguments, mask, probmap, exception_type):
with self.assertRaises(exception_type):
diff --git a/tests/test_generate_watershed_markersd.py b/tests/test_generate_watershed_markersd.py
index 76d4ec1ae6..a3c2b9c231 100644
--- a/tests/test_generate_watershed_markersd.py
+++ b/tests/test_generate_watershed_markersd.py
@@ -68,6 +68,7 @@
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
@unittest.skipUnless(has_scipy, "Requires scipy library.")
class TestGenerateWatershedMarkersd(unittest.TestCase):
+
@parameterized.expand(EXCEPTION_TESTS)
def test_value(self, arguments, mask, border_map, exception_type):
with self.assertRaises(exception_type):
diff --git a/tests/test_generate_watershed_mask.py b/tests/test_generate_watershed_mask.py
index 1cc35dca5c..5224a912b0 100644
--- a/tests/test_generate_watershed_mask.py
+++ b/tests/test_generate_watershed_mask.py
@@ -58,6 +58,7 @@
@unittest.skipUnless(has_scipy, "Requires scipy library.")
class TestGenerateWatershedMask(unittest.TestCase):
+
@parameterized.expand(EXCEPTION_TESTS)
def test_value(self, arguments, exception_type):
with self.assertRaises(exception_type):
diff --git a/tests/test_generate_watershed_maskd.py b/tests/test_generate_watershed_maskd.py
index aa6d5bf03a..9d0f2c274a 100644
--- a/tests/test_generate_watershed_maskd.py
+++ b/tests/test_generate_watershed_maskd.py
@@ -58,6 +58,7 @@
@unittest.skipUnless(has_scipy, "Requires scipy library.")
class TestGenerateWatershedMaskd(unittest.TestCase):
+
@parameterized.expand(EXCEPTION_TESTS)
def test_value(self, arguments, exception_type):
with self.assertRaises(exception_type):
diff --git a/tests/test_generator.py b/tests/test_generator.py
index c336acf7ef..f531f928da 100644
--- a/tests/test_generator.py
+++ b/tests/test_generator.py
@@ -42,6 +42,7 @@
class TestGenerator(unittest.TestCase):
+
@parameterized.expand(CASES)
def test_shape(self, input_param, input_data, expected_shape):
net = Generator(**input_param)
diff --git a/tests/test_get_equivalent_dtype.py b/tests/test_get_equivalent_dtype.py
index 299a3963b7..2b4de1bc2a 100644
--- a/tests/test_get_equivalent_dtype.py
+++ b/tests/test_get_equivalent_dtype.py
@@ -29,6 +29,7 @@
class TestGetEquivalentDtype(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_get_equivalent_dtype(self, im, input_dtype):
out_dtype = get_equivalent_dtype(input_dtype, type(im))
diff --git a/tests/test_get_extreme_points.py b/tests/test_get_extreme_points.py
index 1338ba0e2c..e60715e2fe 100644
--- a/tests/test_get_extreme_points.py
+++ b/tests/test_get_extreme_points.py
@@ -47,6 +47,7 @@
class TestGetExtremePoints(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_type_shape(self, input_data, expected):
result = get_extreme_points(**input_data)
diff --git a/tests/test_get_layers.py b/tests/test_get_layers.py
index ad0be1a5c4..5c020892ed 100644
--- a/tests/test_get_layers.py
+++ b/tests/test_get_layers.py
@@ -37,6 +37,7 @@
class TestGetLayers(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_NORM)
def test_norm_layer(self, input_param, expected):
layer = get_norm_layer(**input_param)
@@ -54,6 +55,7 @@ def test_dropout_layer(self, input_param, expected):
class TestSuggestion(unittest.TestCase):
+
def test_suggested(self):
with self.assertRaisesRegex(ValueError, "did you mean 'GROUP'?"):
get_norm_layer(name="grop", spatial_dims=2)
diff --git a/tests/test_get_package_version.py b/tests/test_get_package_version.py
index 1881d79602..ab9e69cd31 100644
--- a/tests/test_get_package_version.py
+++ b/tests/test_get_package_version.py
@@ -17,6 +17,7 @@
class TestGetVersion(unittest.TestCase):
+
def test_default(self):
output = get_package_version("42foobarnoexist")
self.assertTrue("UNKNOWN" in output)
diff --git a/tests/test_get_unique_labels.py b/tests/test_get_unique_labels.py
index e550882243..0a88145489 100644
--- a/tests/test_get_unique_labels.py
+++ b/tests/test_get_unique_labels.py
@@ -35,6 +35,7 @@
class TestGetUniqueLabels(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_correct_results(self, args, expected):
result = get_unique_labels(**args)
diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py
index aad5d6fea6..bdc66b9495 100644
--- a/tests/test_gibbs_noise.py
+++ b/tests/test_gibbs_noise.py
@@ -32,6 +32,7 @@
class TestGibbsNoise(unittest.TestCase):
+
def setUp(self):
set_determinism(0)
super().setUp()
diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py
index 3aa69b7280..3b2cae7e84 100644
--- a/tests/test_gibbs_noised.py
+++ b/tests/test_gibbs_noised.py
@@ -33,6 +33,7 @@
class TestGibbsNoised(unittest.TestCase):
+
def setUp(self):
set_determinism(0)
super().setUp()
diff --git a/tests/test_giou_loss.py b/tests/test_giou_loss.py
index e794ddab30..34ee22e0ad 100644
--- a/tests/test_giou_loss.py
+++ b/tests/test_giou_loss.py
@@ -35,6 +35,7 @@
class TestGIoULoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_result(self, input_data, expected_val):
loss = BoxGIoULoss()
diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py
index b67ed71725..36a1978c93 100644
--- a/tests/test_global_mutual_information_loss.py
+++ b/tests/test_global_mutual_information_loss.py
@@ -54,6 +54,7 @@
@skip_if_quick
class TestGlobalMutualInformationLoss(unittest.TestCase):
+
def setUp(self):
config = testing_data_config("images", "Prostate_T2W_AX_1")
download_url_or_skip_test(
@@ -114,6 +115,7 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.
class TestGlobalMutualInformationLossIll(unittest.TestCase):
+
def test_ill_shape(self):
loss = GlobalMutualInformationLoss()
with self.assertRaisesRegex(ValueError, ""):
diff --git a/tests/test_globalnet.py b/tests/test_globalnet.py
index 1ab8db5926..626053377c 100644
--- a/tests/test_globalnet.py
+++ b/tests/test_globalnet.py
@@ -65,6 +65,7 @@
class TestAffineHead(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_AFFINE_TRANSFORM)
def test_shape(self, input_param, theta, expected_val):
layer = AffineHead(**input_param)
@@ -78,6 +79,7 @@ def test_shape(self, input_param, theta, expected_val):
class TestGlobalNet(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_GLOBAL_NET)
def test_shape(self, input_param, input_shape, expected_shape):
net = GlobalNet(**input_param).to(device)
diff --git a/tests/test_gmm.py b/tests/test_gmm.py
index eb638f5479..549e8f1ec4 100644
--- a/tests/test_gmm.py
+++ b/tests/test_gmm.py
@@ -261,6 +261,7 @@
@skip_if_quick
class GMMTestCase(unittest.TestCase):
+
def setUp(self):
self._var = os.environ.get("TORCH_EXTENSIONS_DIR")
self.tempdir = tempfile.mkdtemp()
diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py
index d937a5e266..4a3b4b6340 100644
--- a/tests/test_grid_dataset.py
+++ b/tests/test_grid_dataset.py
@@ -58,6 +58,7 @@ def identity_generator(x):
class TestGridPatchDataset(unittest.TestCase):
+
def setUp(self):
set_determinism(seed=1234)
diff --git a/tests/test_grid_distortion.py b/tests/test_grid_distortion.py
index 1a698140af..9ec85250e8 100644
--- a/tests/test_grid_distortion.py
+++ b/tests/test_grid_distortion.py
@@ -99,6 +99,7 @@
class TestGridDistortion(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_grid_distortion(self, input_param, input_data, expected_val):
g = GridDistortion(**input_param)
diff --git a/tests/test_grid_distortiond.py b/tests/test_grid_distortiond.py
index a645eb4f87..ce73593dc7 100644
--- a/tests/test_grid_distortiond.py
+++ b/tests/test_grid_distortiond.py
@@ -75,6 +75,7 @@
class TestGridDistortiond(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_grid_distortiond(self, input_param, input_data, expected_val_img, expected_val_mask):
g = GridDistortiond(**input_param)
diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py
index cd1c5b6531..4b324eda1a 100644
--- a/tests/test_grid_patch.py
+++ b/tests/test_grid_patch.py
@@ -97,6 +97,7 @@
class TestGridPatch(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
@SkipIfBeforePyTorchVersion((1, 11, 1))
def test_grid_patch(self, in_type, input_parameters, image, expected):
diff --git a/tests/test_grid_patchd.py b/tests/test_grid_patchd.py
index 4f317e4677..53313b3a8f 100644
--- a/tests/test_grid_patchd.py
+++ b/tests/test_grid_patchd.py
@@ -77,6 +77,7 @@
class TestGridPatchd(unittest.TestCase):
+
@parameterized.expand(TEST_SINGLE)
@SkipIfBeforePyTorchVersion((1, 11, 1))
def test_grid_patchd(self, in_type, input_parameters, image_dict, expected):
diff --git a/tests/test_grid_pull.py b/tests/test_grid_pull.py
index 8877b0c121..f80874d216 100644
--- a/tests/test_grid_pull.py
+++ b/tests/test_grid_pull.py
@@ -63,6 +63,7 @@ def make_grid(shape, dtype=None, device=None, requires_grad=True):
@skip_if_no_cpp_extension
class TestGridPull(unittest.TestCase):
+
@parameterized.expand(TEST_1D_GP, skip_on_empty=True)
def test_grid_pull(self, input_param, expected):
result = grid_pull(**input_param)
diff --git a/tests/test_grid_split.py b/tests/test_grid_split.py
index 3ccf6e75a8..852a4847a6 100644
--- a/tests/test_grid_split.py
+++ b/tests/test_grid_split.py
@@ -66,6 +66,7 @@
class TestGridSplit(unittest.TestCase):
+
@parameterized.expand(TEST_SINGLE)
def test_split_patch_single_call(self, in_type, input_parameters, image, expected):
input_image = in_type(image)
diff --git a/tests/test_grid_splitd.py b/tests/test_grid_splitd.py
index d8519b2121..215076d5a3 100644
--- a/tests/test_grid_splitd.py
+++ b/tests/test_grid_splitd.py
@@ -70,6 +70,7 @@
class TestGridSplitd(unittest.TestCase):
+
@parameterized.expand(TEST_SINGLE)
def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expected):
input_dict = {}
diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py
index 7dfb802bba..7b281665b4 100644
--- a/tests/test_handler_checkpoint_loader.py
+++ b/tests/test_handler_checkpoint_loader.py
@@ -23,6 +23,7 @@
class TestHandlerCheckpointLoader(unittest.TestCase):
+
def test_one_save_one_load(self):
net1 = torch.nn.PReLU()
data1 = net1.state_dict()
diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py
index 70810e018f..42f99e57c9 100644
--- a/tests/test_handler_checkpoint_saver.py
+++ b/tests/test_handler_checkpoint_saver.py
@@ -111,6 +111,7 @@
class TestHandlerCheckpointSaver(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]
)
diff --git a/tests/test_handler_classification_saver.py b/tests/test_handler_classification_saver.py
index 905e326a66..5330e48dda 100644
--- a/tests/test_handler_classification_saver.py
+++ b/tests/test_handler_classification_saver.py
@@ -26,6 +26,7 @@
class TestHandlerClassificationSaver(unittest.TestCase):
+
def test_saved_content(self):
with tempfile.TemporaryDirectory() as tempdir:
# set up engine
diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py
index ef06b69683..47dca2d999 100644
--- a/tests/test_handler_classification_saver_dist.py
+++ b/tests/test_handler_classification_saver_dist.py
@@ -27,6 +27,7 @@
class DistributedHandlerClassificationSaver(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2)
def test_saved_content(self):
with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_handler_clearml_image.py b/tests/test_handler_clearml_image.py
index 13eebed120..91aa297b7f 100644
--- a/tests/test_handler_clearml_image.py
+++ b/tests/test_handler_clearml_image.py
@@ -29,6 +29,7 @@
@unittest.skipUnless(has_tb, "Requires SummaryWriter installation")
@unittest.skipIf(not has_get_active_config_file, "ClearML 'get_active_config_file' not found")
class TestHandlerClearMLImageHandler(unittest.TestCase):
+
def test_task_init(self):
handle, path = tempfile.mkstemp()
with open(handle, "w") as new_config:
diff --git a/tests/test_handler_clearml_stats.py b/tests/test_handler_clearml_stats.py
index a460bc2391..159f6af4eb 100644
--- a/tests/test_handler_clearml_stats.py
+++ b/tests/test_handler_clearml_stats.py
@@ -29,6 +29,7 @@
@unittest.skipUnless(has_tb, "Requires SummaryWriter installation")
@unittest.skipIf(not has_get_active_config_file, "ClearML 'get_active_config_file' not found")
class TestHandlerClearMLStatsHandler(unittest.TestCase):
+
def test_task_init(self):
handle, path = tempfile.mkstemp()
with open(handle, "w") as new_config:
diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py
index b74b7e57c4..dd30f04142 100644
--- a/tests/test_handler_confusion_matrix_dist.py
+++ b/tests/test_handler_confusion_matrix_dist.py
@@ -23,6 +23,7 @@
class DistributedConfusionMatrix(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2)
def test_compute(self):
self._compute()
diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py
index 5bc5584515..37ca7f6870 100644
--- a/tests/test_handler_decollate_batch.py
+++ b/tests/test_handler_decollate_batch.py
@@ -22,6 +22,7 @@
class TestHandlerDecollateBatch(unittest.TestCase):
+
def test_compute(self):
data = [
{"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]},
diff --git a/tests/test_handler_early_stop.py b/tests/test_handler_early_stop.py
index 675a804472..5fbb828330 100644
--- a/tests/test_handler_early_stop.py
+++ b/tests/test_handler_early_stop.py
@@ -19,7 +19,9 @@
class TestHandlerEarlyStop(unittest.TestCase):
+
def test_early_stop_train_loss(self):
+
def _train_func(engine, batch):
return {"loss": 1.5}
@@ -33,6 +35,7 @@ def _train_func(engine, batch):
self.assertEqual(trainer.state.epoch, 2)
def test_early_stop_val_metric(self):
+
def _train_func(engine, batch):
pass
diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py
index f64039b6fb..317eba1b11 100644
--- a/tests/test_handler_garbage_collector.py
+++ b/tests/test_handler_garbage_collector.py
@@ -34,6 +34,7 @@
class TestHandlerGarbageCollector(unittest.TestCase):
+
@skipUnless(has_ignite, "Requires ignite")
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
def test_content(self, data, trigger_event):
diff --git a/tests/test_handler_ignite_metric.py b/tests/test_handler_ignite_metric.py
index dbdc765b45..28e0b69621 100644
--- a/tests/test_handler_ignite_metric.py
+++ b/tests/test_handler_ignite_metric.py
@@ -99,6 +99,7 @@
class TestHandlerIgniteMetricHandler(unittest.TestCase):
+
@SkipIfNoModule("ignite")
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_metric_fn(self, loss_params, metric_params, handler_params, expected_avg):
diff --git a/tests/test_handler_logfile.py b/tests/test_handler_logfile.py
index f09876ab0a..457aca2ebc 100644
--- a/tests/test_handler_logfile.py
+++ b/tests/test_handler_logfile.py
@@ -30,6 +30,7 @@
class TestHandlerLogfile(unittest.TestCase):
+
def setUp(self):
if has_ignite:
# set up engine
diff --git a/tests/test_handler_lr_scheduler.py b/tests/test_handler_lr_scheduler.py
index f1d3f45f06..3efb4a789f 100644
--- a/tests/test_handler_lr_scheduler.py
+++ b/tests/test_handler_lr_scheduler.py
@@ -25,6 +25,7 @@
class TestHandlerLrSchedule(unittest.TestCase):
+
def test_content(self):
data = [0] * 8
test_lr = 0.1
diff --git a/tests/test_handler_metric_logger.py b/tests/test_handler_metric_logger.py
index 016af1e8b5..06d50e97ff 100644
--- a/tests/test_handler_metric_logger.py
+++ b/tests/test_handler_metric_logger.py
@@ -28,6 +28,7 @@
class TestHandlerMetricLogger(unittest.TestCase):
+
@SkipIfNoModule("ignite")
def test_metric_logging(self):
dummy_name = "dummy"
diff --git a/tests/test_handler_metrics_reloaded.py b/tests/test_handler_metrics_reloaded.py
index e080204d6f..b8fb39d2e8 100644
--- a/tests/test_handler_metrics_reloaded.py
+++ b/tests/test_handler_metrics_reloaded.py
@@ -73,6 +73,7 @@
@unittest.skipIf(not has_metrics, "MetricsReloaded not available.")
class TestHandlerMetricsReloadedBinary(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_BIN_1, TEST_CASE_BIN_2, TEST_CASE_BIN_3])
def test_compute(self, input_params, y_pred, y, expected_value):
input_params["output_transform"] = from_engine(["pred", "label"])
@@ -113,6 +114,7 @@ def test_shape_mismatch(self, input_params, _y_pred, _y, _expected_value):
@unittest.skipIf(not has_metrics, "MetricsReloaded not available.")
class TestMetricsReloadedCategorical(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_CAT_1, TEST_CASE_CAT_2])
def test_compute(self, input_params, y_pred, y, expected_value):
input_params["output_transform"] = from_engine(["pred", "label"])
diff --git a/tests/test_handler_metrics_saver.py b/tests/test_handler_metrics_saver.py
index 9888a73e5f..d5ad2f4841 100644
--- a/tests/test_handler_metrics_saver.py
+++ b/tests/test_handler_metrics_saver.py
@@ -24,6 +24,7 @@
class TestHandlerMetricsSaver(unittest.TestCase):
+
def test_content(self):
with tempfile.TemporaryDirectory() as tempdir:
metrics_saver = MetricsSaver(
diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py
index 11d7db168b..46c9ad27d7 100644
--- a/tests/test_handler_metrics_saver_dist.py
+++ b/tests/test_handler_metrics_saver_dist.py
@@ -27,6 +27,7 @@
class DistributedMetricsSaver(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2)
def test_content(self):
with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py
index 92cf17eadb..44adc49fc2 100644
--- a/tests/test_handler_mlflow.py
+++ b/tests/test_handler_mlflow.py
@@ -33,6 +33,7 @@
def get_event_filter(e):
+
def event_filter(_, event):
if event in e:
return True
@@ -65,6 +66,7 @@ def _train_func(engine, batch):
class TestHandlerMLFlow(unittest.TestCase):
+
def setUp(self):
self.tmpdir_list = []
diff --git a/tests/test_handler_nvtx.py b/tests/test_handler_nvtx.py
index 75cc5bc1f4..a0d1cdb4d5 100644
--- a/tests/test_handler_nvtx.py
+++ b/tests/test_handler_nvtx.py
@@ -36,6 +36,7 @@
class TestHandlerDecollateBatch(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_0, TEST_CASE_1])
@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!")
def test_compute(self, data, expected):
diff --git a/tests/test_handler_panoptic_quality.py b/tests/test_handler_panoptic_quality.py
index 1595b5ad2c..337f9c7b49 100644
--- a/tests/test_handler_panoptic_quality.py
+++ b/tests/test_handler_panoptic_quality.py
@@ -60,6 +60,7 @@
@SkipIfNoModule("scipy.optimize")
class TestHandlerPanopticQuality(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_compute(self, input_params, expected_avg):
metric = PanopticQuality(**input_params)
diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py
index 1e7bbb7588..0bcc794381 100644
--- a/tests/test_handler_parameter_scheduler.py
+++ b/tests/test_handler_parameter_scheduler.py
@@ -21,6 +21,7 @@
class ToyNet(Module):
+
def __init__(self, value):
super().__init__()
self.value = value
@@ -36,6 +37,7 @@ def set_value(self, value):
class TestHandlerParameterScheduler(unittest.TestCase):
+
def test_linear_scheduler(self):
# Testing step_constant
net = ToyNet(value=-1)
@@ -116,6 +118,7 @@ def test_multistep_scheduler(self):
assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
def test_custom_scheduler(self):
+
def custom_logic(initial_value, gamma, current_step):
return initial_value * gamma ** (current_step % 9)
diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py
index c449665c1e..0dd518325b 100644
--- a/tests/test_handler_post_processing.py
+++ b/tests/test_handler_post_processing.py
@@ -40,6 +40,7 @@
class TestHandlerPostProcessing(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_compute(self, input_params, decollate, expected):
data = [
diff --git a/tests/test_handler_prob_map_producer.py b/tests/test_handler_prob_map_producer.py
index 153a00b1ac..347f8cb92c 100644
--- a/tests/test_handler_prob_map_producer.py
+++ b/tests/test_handler_prob_map_producer.py
@@ -30,6 +30,7 @@
class TestDataset(Dataset):
+
def __init__(self, name, size):
super().__init__(
data=[
@@ -63,11 +64,13 @@ def __getitem__(self, index):
class TestEvaluator(Evaluator):
+
def _iteration(self, engine, batchdata):
return batchdata
class TestHandlerProbMapGenerator(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
def test_prob_map_generator(self, name, size):
# set up dataset
diff --git a/tests/test_handler_regression_metrics.py b/tests/test_handler_regression_metrics.py
index a06452c54d..a3ec9f071a 100644
--- a/tests/test_handler_regression_metrics.py
+++ b/tests/test_handler_regression_metrics.py
@@ -46,6 +46,7 @@ def psnrmetric_np(max_val, y_pred, y):
class TestHandlerRegressionMetrics(unittest.TestCase):
+
def test_compute(self):
set_determinism(seed=123)
device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/tests/test_handler_regression_metrics_dist.py b/tests/test_handler_regression_metrics_dist.py
index a2e96b97d9..f57db429e8 100644
--- a/tests/test_handler_regression_metrics_dist.py
+++ b/tests/test_handler_regression_metrics_dist.py
@@ -57,6 +57,7 @@ def psnrmetric_np(max_val, y_pred, y):
class DistributedMeanSquaredError(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2)
def test_compute(self):
set_determinism(123)
@@ -103,6 +104,7 @@ def _val_func(engine, batch):
class DistributedMeanAbsoluteError(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2)
def test_compute(self):
set_determinism(123)
@@ -149,6 +151,7 @@ def _val_func(engine, batch):
class DistributedRootMeanSquaredError(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2)
def test_compute(self):
set_determinism(123)
@@ -195,6 +198,7 @@ def _val_func(engine, batch):
class DistributedPeakSignalToNoiseRatio(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2)
def test_compute(self):
set_determinism(123)
diff --git a/tests/test_handler_rocauc.py b/tests/test_handler_rocauc.py
index ce2351a9f5..2c771340f9 100644
--- a/tests/test_handler_rocauc.py
+++ b/tests/test_handler_rocauc.py
@@ -21,6 +21,7 @@
class TestHandlerROCAUC(unittest.TestCase):
+
def test_compute(self):
auc_metric = ROCAUC()
act = Activations(softmax=True)
diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py
index 5b6ea045c7..6088251b11 100644
--- a/tests/test_handler_rocauc_dist.py
+++ b/tests/test_handler_rocauc_dist.py
@@ -23,6 +23,7 @@
class DistributedROCAUC(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
def test_compute(self):
auc_metric = ROCAUC()
diff --git a/tests/test_handler_smartcache.py b/tests/test_handler_smartcache.py
index c3b4d72cb4..e544d39c72 100644
--- a/tests/test_handler_smartcache.py
+++ b/tests/test_handler_smartcache.py
@@ -22,6 +22,7 @@
class TestHandlerSmartCache(unittest.TestCase):
+
def test_content(self):
data = [0, 1, 2, 3, 4, 5, 6, 7, 8]
expected = [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7], [4, 5, 6, 7, 8]]
diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py
index 1842e08635..f876cff2a3 100644
--- a/tests/test_handler_stats.py
+++ b/tests/test_handler_stats.py
@@ -26,6 +26,7 @@
def get_event_filter(e):
+
def event_filter(_, event):
if event in e:
return True
@@ -35,6 +36,7 @@ def event_filter(_, event):
class TestHandlerStats(unittest.TestCase):
+
@parameterized.expand([[True], [get_event_filter([1, 2])]])
def test_metrics_print(self, epoch_log):
log_stream = StringIO()
diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py
index 68b71ff7f9..197b175278 100644
--- a/tests/test_handler_tb_image.py
+++ b/tests/test_handler_tb_image.py
@@ -33,6 +33,7 @@
@unittest.skipUnless(has_tb, "Requires SummaryWriter installation")
@SkipIfBeforePyTorchVersion((1, 13)) # issue 6683
class TestHandlerTBImage(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_tb_image_shape(self, shape):
with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py
index 883827a1ac..b96bea13a1 100644
--- a/tests/test_handler_tb_stats.py
+++ b/tests/test_handler_tb_stats.py
@@ -26,6 +26,7 @@
def get_event_filter(e):
+
def event_filter(_, event):
if event in e:
return True
@@ -36,6 +37,7 @@ def event_filter(_, event):
@unittest.skipUnless(has_tb, "Requires SummaryWriter installation")
class TestHandlerTBStats(unittest.TestCase):
+
def test_metrics_print(self):
with tempfile.TemporaryDirectory() as tempdir:
# set up engine
diff --git a/tests/test_handler_validation.py b/tests/test_handler_validation.py
index e1ccba2294..752b1d3df7 100644
--- a/tests/test_handler_validation.py
+++ b/tests/test_handler_validation.py
@@ -22,12 +22,14 @@
class TestEvaluator(Evaluator):
+
def _iteration(self, engine, batchdata):
engine.state.output = "called"
return engine.state.output
class TestHandlerValidation(unittest.TestCase):
+
def test_content(self):
data = [0] * 8
diff --git a/tests/test_hardnegsampler.py b/tests/test_hardnegsampler.py
index b33cea1537..5385abd1db 100644
--- a/tests/test_hardnegsampler.py
+++ b/tests/test_hardnegsampler.py
@@ -37,6 +37,7 @@
class TestSampleSlices(unittest.TestCase):
+
@parameterized.expand(TEST_CASE)
def test_shape(self, target_label0, target_label1, concat_fg_probs, expected_result_pos, expected_result_neg):
compute_dtypes = [torch.float16, torch.float32]
diff --git a/tests/test_hashing.py b/tests/test_hashing.py
index 093de47cf9..61b3e7056b 100644
--- a/tests/test_hashing.py
+++ b/tests/test_hashing.py
@@ -20,6 +20,7 @@
class TestPickleHashing(unittest.TestCase):
+
def test_pickle(self):
set_determinism(0)
data1 = np.random.rand(10)
@@ -45,6 +46,7 @@ def test_pickle(self):
class TestJSONHashing(unittest.TestCase):
+
def test_json(self):
data_dict1 = {"b": "str2", "a": "str1"}
data_dict2 = {"a": "str1", "b": "str2"}
diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py
index 71bbad36d2..20276a1832 100644
--- a/tests/test_hausdorff_distance.py
+++ b/tests/test_hausdorff_distance.py
@@ -168,6 +168,7 @@ def _describe_test_case(test_func, test_number, params):
class TestHausdorffDistance(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_EXPANDED, doc_func=_describe_test_case)
def test_value(self, device, metric, directed, input_data, expected_value):
percentile = None
diff --git a/tests/test_hausdorff_loss.py b/tests/test_hausdorff_loss.py
index 5ed20f5f3b..f279d45b14 100644
--- a/tests/test_hausdorff_loss.py
+++ b/tests/test_hausdorff_loss.py
@@ -198,6 +198,7 @@ def _describe_test_case(test_func, test_number, params):
@skipUnless(has_scipy, "Scipy required")
class TestHausdorffDTLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES, doc_func=_describe_test_case)
def test_shape(self, input_param, input_data, expected_val):
result = HausdorffDTLoss(**input_param).forward(**input_data)
@@ -234,6 +235,7 @@ def test_input_warnings(self):
@skipUnless(has_scipy, "Scipy required")
class TesLogtHausdorffDTLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_LOG, doc_func=_describe_test_case)
def test_shape(self, input_param, input_data, expected_val):
result = LogHausdorffDTLoss(**input_param).forward(**input_data)
diff --git a/tests/test_header_correct.py b/tests/test_header_correct.py
index 71fed1e35d..c0ea2a8643 100644
--- a/tests/test_header_correct.py
+++ b/tests/test_header_correct.py
@@ -20,6 +20,7 @@
class TestCorrection(unittest.TestCase):
+
def test_correct(self):
test_img = nib.Nifti1Image(np.zeros((1, 2, 3)), np.eye(4))
test_img.header.set_zooms((100, 100, 100))
diff --git a/tests/test_highresnet.py b/tests/test_highresnet.py
index 04520419b7..bcc5739900 100644
--- a/tests/test_highresnet.py
+++ b/tests/test_highresnet.py
@@ -48,6 +48,7 @@
class TestHighResNet(DistTestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_shape(self, input_param, input_shape, expected_shape):
net = HighResNet(**input_param).to(device)
diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py
index 68fa0b1192..879a74969d 100644
--- a/tests/test_hilbert_transform.py
+++ b/tests/test_hilbert_transform.py
@@ -161,6 +161,7 @@ def create_expected_numpy_output(input_datum, **kwargs):
@SkipIfNoModule("torch.fft")
class TestHilbertTransformCPU(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_1D_SINE_CPU,
@@ -179,6 +180,7 @@ def test_value(self, arguments, image, expected_data, atol):
@skip_if_no_cuda
@SkipIfNoModule("torch.fft")
class TestHilbertTransformGPU(unittest.TestCase):
+
@parameterized.expand(
(
[]
@@ -201,6 +203,7 @@ def test_value(self, arguments, image, expected_data, atol):
@SkipIfModule("torch.fft")
class TestHilbertTransformNoFFTMod(unittest.TestCase):
+
def test_no_fft_module_error(self):
self.assertRaises(OptionalImportError, HilbertTransform(), torch.randn(1, 1, 10))
diff --git a/tests/test_histogram_normalize.py b/tests/test_histogram_normalize.py
index 3a340db52a..25c0afb64d 100644
--- a/tests/test_histogram_normalize.py
+++ b/tests/test_histogram_normalize.py
@@ -48,6 +48,7 @@
class TestHistogramNormalize(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
result = HistogramNormalize(**arguments)(image)
diff --git a/tests/test_histogram_normalized.py b/tests/test_histogram_normalized.py
index 24f27d225e..a390375441 100644
--- a/tests/test_histogram_normalized.py
+++ b/tests/test_histogram_normalized.py
@@ -48,6 +48,7 @@
class TestHistogramNormalized(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
result = HistogramNormalized(**arguments)(image)["img"]
diff --git a/tests/test_hovernet.py b/tests/test_hovernet.py
index d768895bdc..fb4946b011 100644
--- a/tests/test_hovernet.py
+++ b/tests/test_hovernet.py
@@ -154,6 +154,7 @@ def check_kernels(net, mode):
class TestHoverNet(unittest.TestCase):
+
@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shapes):
input_param["decoder_padding"] = False
diff --git a/tests/test_hovernet_instance_map_post_processing.py b/tests/test_hovernet_instance_map_post_processing.py
index 990e2d9a10..ce272fba1a 100644
--- a/tests/test_hovernet_instance_map_post_processing.py
+++ b/tests/test_hovernet_instance_map_post_processing.py
@@ -42,6 +42,7 @@
@unittest.skipUnless(has_scipy, "Requires scipy library.")
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
class TestHoVerNetInstanceMapPostProcessing(unittest.TestCase):
+
@parameterized.expand(TEST_CASE)
def test_value(self, in_type, test_data, kwargs, expected_info, expected_map):
nuclear_prediction = in_type(test_data.astype(float))
diff --git a/tests/test_hovernet_instance_map_post_processingd.py b/tests/test_hovernet_instance_map_post_processingd.py
index 69e42d3495..c982156caa 100644
--- a/tests/test_hovernet_instance_map_post_processingd.py
+++ b/tests/test_hovernet_instance_map_post_processingd.py
@@ -43,6 +43,7 @@
@unittest.skipUnless(has_scipy, "Requires scipy library.")
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
class TestHoVerNetInstanceMapPostProcessingd(unittest.TestCase):
+
@parameterized.expand(TEST_CASE)
def test_value(self, in_type, test_data, kwargs, expected_info, expected_map):
input = {
diff --git a/tests/test_hovernet_loss.py b/tests/test_hovernet_loss.py
index 10db4518fa..b7cd1f3104 100644
--- a/tests/test_hovernet_loss.py
+++ b/tests/test_hovernet_loss.py
@@ -35,6 +35,7 @@
class PrepareTestInputs:
+
def __init__(self, inputs):
self.inputs = {HoVerNetBranch.NP: inputs[1], HoVerNetBranch.HV: inputs[3]}
self.targets = {HoVerNetBranch.NP: inputs[0], HoVerNetBranch.HV: inputs[2]}
@@ -171,6 +172,7 @@ def test_shape_generator(num_classes=1, num_objects=3, batch_size=1, height=5, w
class TestHoverNetLoss(unittest.TestCase):
+
@parameterized.expand(CASES)
def test_shape(self, input_param, expected_loss):
loss = HoVerNetLoss()
diff --git a/tests/test_hovernet_nuclear_type_post_processing.py b/tests/test_hovernet_nuclear_type_post_processing.py
index f2b33c96ae..e97b7abd2c 100644
--- a/tests/test_hovernet_nuclear_type_post_processing.py
+++ b/tests/test_hovernet_nuclear_type_post_processing.py
@@ -41,6 +41,7 @@
@unittest.skipUnless(has_scipy, "Requires scipy library.")
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
class TestHoVerNetNuclearTypePostProcessing(unittest.TestCase):
+
@parameterized.expand(TEST_CASE)
def test_value(self, in_type, test_data, kwargs, expected_info, expected_map):
nuclear_prediction = in_type(test_data.astype(float))
diff --git a/tests/test_hovernet_nuclear_type_post_processingd.py b/tests/test_hovernet_nuclear_type_post_processingd.py
index 01478b7961..26cf80592c 100644
--- a/tests/test_hovernet_nuclear_type_post_processingd.py
+++ b/tests/test_hovernet_nuclear_type_post_processingd.py
@@ -42,6 +42,7 @@
@unittest.skipUnless(has_scipy, "Requires scipy library.")
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
class TestHoVerNetNuclearTypePostProcessingd(unittest.TestCase):
+
@parameterized.expand(TEST_CASE)
def test_value(self, in_type, test_data, kwargs, expected):
input = {
diff --git a/tests/test_identity.py b/tests/test_identity.py
index 19116cbb8f..4243a7f19a 100644
--- a/tests/test_identity.py
+++ b/tests/test_identity.py
@@ -18,6 +18,7 @@
class TestIdentity(NumpyImageTestCase2D):
+
def test_identity(self):
for p in TEST_NDARRAYS:
img = p(self.imt)
diff --git a/tests/test_identityd.py b/tests/test_identityd.py
index 98499def01..6b81ad9f16 100644
--- a/tests/test_identityd.py
+++ b/tests/test_identityd.py
@@ -18,6 +18,7 @@
class TestIdentityd(NumpyImageTestCase2D):
+
def test_identityd(self):
for p in TEST_NDARRAYS:
img = p(self.imt)
diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py
index 7f7bdec513..fc8b4b6ccb 100644
--- a/tests/test_image_dataset.py
+++ b/tests/test_image_dataset.py
@@ -47,6 +47,7 @@ def __call__(self, data):
class _TestCompose(Compose):
+
def __call__(self, data, meta, lazy):
data = self.transforms[0](data) # ensure channel first
data = self.transforms[1](data, lazy=lazy) # spacing
@@ -57,6 +58,7 @@ def __call__(self, data, meta, lazy):
class TestImageDataset(unittest.TestCase):
+
def test_use_case(self):
with tempfile.TemporaryDirectory() as tempdir:
img_ = nib.Nifti1Image(np.random.randint(0, 2, size=(20, 20, 20)).astype(float), np.eye(4))
diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py
index 985ea95e79..adc9dade9c 100644
--- a/tests/test_image_filter.py
+++ b/tests/test_image_filter.py
@@ -38,6 +38,7 @@
class TestModule(torch.nn.Module):
+
def __init__(self):
super().__init__()
@@ -50,6 +51,7 @@ class TestNotAModuleOrTransform:
class TestImageFilter(unittest.TestCase):
+
@parameterized.expand(SUPPORTED_FILTERS)
def test_init_from_string(self, filter_name):
"Test init from string"
@@ -133,6 +135,7 @@ def test_pass_empty_metadata_dict(self):
class TestImageFilterDict(unittest.TestCase):
+
@parameterized.expand(SUPPORTED_FILTERS)
def test_init_from_string_dict(self, filter_name):
"Test init from string and assert an error is thrown if no size is passed"
@@ -162,6 +165,7 @@ def test_call_3d(self, filter_name):
class TestRandImageFilter(unittest.TestCase):
+
@parameterized.expand(SUPPORTED_FILTERS)
def test_init_from_string(self, filter_name):
"Test init from string and assert an error is thrown if no size is passed"
@@ -205,6 +209,7 @@ def test_call_3d_prob_0(self, filter_name):
class TestRandImageFilterDict(unittest.TestCase):
+
@parameterized.expand(SUPPORTED_FILTERS)
def test_init_from_string_dict(self, filter_name):
"Test init from string and assert an error is thrown if no size is passed"
diff --git a/tests/test_image_rw.py b/tests/test_image_rw.py
index 79e51c53eb..7e1c1deecc 100644
--- a/tests/test_image_rw.py
+++ b/tests/test_image_rw.py
@@ -33,6 +33,7 @@
@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadSaveNifti(unittest.TestCase):
+
def setUp(self):
self.test_dir = tempfile.mkdtemp()
@@ -97,6 +98,7 @@ def test_4d(self, reader, writer):
@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadSavePNG(unittest.TestCase):
+
def setUp(self):
self.test_dir = tempfile.mkdtemp()
@@ -137,6 +139,7 @@ def test_rgb(self, reader, writer):
class TestRegRes(unittest.TestCase):
+
def test_0_default(self):
self.assertTrue(len(resolve_writer(".png")) > 0, "has png writer")
self.assertTrue(len(resolve_writer(".nrrd")) > 0, "has nrrd writer")
@@ -153,6 +156,7 @@ def test_1_new(self):
@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadSaveNrrd(unittest.TestCase):
+
def setUp(self):
self.test_dir = tempfile.mkdtemp()
diff --git a/tests/test_img2tensorboard.py b/tests/test_img2tensorboard.py
index 7825f9b4d7..901ca77e7f 100644
--- a/tests/test_img2tensorboard.py
+++ b/tests/test_img2tensorboard.py
@@ -21,6 +21,7 @@
class TestImg2Tensorboard(unittest.TestCase):
+
def test_write_gray(self):
nparr = np.ones(shape=(1, 32, 32, 32), dtype=np.float32)
summary_object_np = make_animated_gif_summary(
diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py
index 1350146220..cb45cb5146 100644
--- a/tests/test_init_reader.py
+++ b/tests/test_init_reader.py
@@ -19,6 +19,7 @@
class TestInitLoadImage(unittest.TestCase):
+
def test_load_image(self):
instance1 = LoadImage(image_only=False, dtype=None)
instance2 = LoadImage(image_only=True, dtype=None)
diff --git a/tests/test_integration_autorunner.py b/tests/test_integration_autorunner.py
index 7110db568d..31a0813abc 100644
--- a/tests/test_integration_autorunner.py
+++ b/tests/test_integration_autorunner.py
@@ -71,6 +71,7 @@
@SkipIfBeforePyTorchVersion((1, 11, 1)) # for mem_get_info
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
class TestAutoRunner(unittest.TestCase):
+
def setUp(self) -> None:
self.test_dir = tempfile.TemporaryDirectory()
test_path = self.test_dir.name
diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py
index bd96f50c55..c2e0fb55b7 100644
--- a/tests/test_integration_bundle_run.py
+++ b/tests/test_integration_bundle_run.py
@@ -37,6 +37,7 @@
class _Runnable42:
+
def __init__(self, val=1):
self.val = val
@@ -46,6 +47,7 @@ def run(self):
class _Runnable43:
+
def __init__(self, func):
self.func = func
@@ -54,6 +56,7 @@ def run(self):
class TestBundleRun(unittest.TestCase):
+
def setUp(self):
self.data_dir = tempfile.mkdtemp()
diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py
index 4fc92c4068..b137fc9b75 100644
--- a/tests/test_integration_classification_2d.py
+++ b/tests/test_integration_classification_2d.py
@@ -45,6 +45,7 @@
class MedNISTDataset(torch.utils.data.Dataset):
+
def __init__(self, image_files, labels, transforms):
self.image_files = image_files
self.labels = labels
@@ -182,6 +183,7 @@ def run_inference_test(root_dir, test_x, test_y, device="cuda:0", num_workers=10
@skip_if_quick
class IntegrationClassification2D(DistTestCase):
+
def setUp(self):
set_determinism(seed=0)
self.data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
diff --git a/tests/test_integration_determinism.py b/tests/test_integration_determinism.py
index 6821279080..3e88f05620 100644
--- a/tests/test_integration_determinism.py
+++ b/tests/test_integration_determinism.py
@@ -26,7 +26,9 @@
def run_test(batch_size=64, train_steps=200, device="cuda:0"):
+
class _TestBatch(Dataset):
+
def __init__(self, transforms):
self.transforms = transforms
@@ -76,6 +78,7 @@ def __len__(self):
class TestDeterminism(DistTestCase):
+
def setUp(self):
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")
diff --git a/tests/test_integration_fast_train.py b/tests/test_integration_fast_train.py
index 497fe22dab..071eb5cf78 100644
--- a/tests/test_integration_fast_train.py
+++ b/tests/test_integration_fast_train.py
@@ -58,6 +58,7 @@
@skip_if_no_cuda
@skip_if_quick
class IntegrationFastTrain(DistTestCase):
+
def setUp(self):
set_determinism(seed=0)
monai.config.print_config()
diff --git a/tests/test_integration_gpu_customization.py b/tests/test_integration_gpu_customization.py
index 44165b967c..043405a580 100644
--- a/tests/test_integration_gpu_customization.py
+++ b/tests/test_integration_gpu_customization.py
@@ -70,6 +70,7 @@
@SkipIfBeforePyTorchVersion((1, 11, 1)) # module 'torch.cuda' has no attribute 'mem_get_info'
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
class TestEnsembleGpuCustomization(unittest.TestCase):
+
def setUp(self) -> None:
self.test_dir = tempfile.TemporaryDirectory()
diff --git a/tests/test_integration_lazy_samples.py b/tests/test_integration_lazy_samples.py
index c365616bc8..51d80e7305 100644
--- a/tests/test_integration_lazy_samples.py
+++ b/tests/test_integration_lazy_samples.py
@@ -160,6 +160,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None,
@skip_if_quick
@SkipIfBeforePyTorchVersion((1, 11))
class IntegrationLazyResampling(DistTestCase):
+
def setUp(self):
monai.config.print_config()
set_determinism(seed=0)
diff --git a/tests/test_integration_nnunetv2_runner.py b/tests/test_integration_nnunetv2_runner.py
index d35737f86f..822d454f52 100644
--- a/tests/test_integration_nnunetv2_runner.py
+++ b/tests/test_integration_nnunetv2_runner.py
@@ -49,6 +49,7 @@
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
@unittest.skipIf(not has_nnunet, "no nnunetv2")
class TestnnUNetV2Runner(unittest.TestCase):
+
def setUp(self) -> None:
self.test_dir = tempfile.TemporaryDirectory()
test_path = self.test_dir.name
diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py
index 2e4cc31645..c72369b151 100644
--- a/tests/test_integration_segmentation_3d.py
+++ b/tests/test_integration_segmentation_3d.py
@@ -235,6 +235,7 @@ def run_inference_test(root_dir, device="cuda:0"):
@skip_if_quick
class IntegrationSegmentation3D(DistTestCase):
+
def setUp(self):
set_determinism(seed=0)
diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py
index bcc66a687e..8b53e94941 100644
--- a/tests/test_integration_sliding_window.py
+++ b/tests/test_integration_sliding_window.py
@@ -72,6 +72,7 @@ def save_func(engine):
@skip_if_quick
class TestIntegrationSlidingWindow(DistTestCase):
+
def setUp(self):
set_determinism(seed=0)
diff --git a/tests/test_integration_stn.py b/tests/test_integration_stn.py
index c858060c31..750a20ea5c 100644
--- a/tests/test_integration_stn.py
+++ b/tests/test_integration_stn.py
@@ -98,6 +98,7 @@ def compare_2d(is_ref=True, device=None, reverse_indexing=False):
class TestSpatialTransformerCore(DistTestCase):
+
def setUp(self):
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")
diff --git a/tests/test_integration_unet_2d.py b/tests/test_integration_unet_2d.py
index 90c0098d36..918190775c 100644
--- a/tests/test_integration_unet_2d.py
+++ b/tests/test_integration_unet_2d.py
@@ -25,7 +25,9 @@
def run_test(net_name="basicunet", batch_size=64, train_steps=100, device="cuda:0"):
+
class _TestBatch(Dataset):
+
def __getitem__(self, _unused_id):
im, seg = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)
return im[None], seg[None].astype(np.float32)
@@ -54,6 +56,7 @@ def __len__(self):
@skip_if_quick
class TestIntegrationUnet2D(DistTestCase):
+
@TimedCall(seconds=20, daemon=False)
def test_unet_training(self):
for n in ["basicunet", "unet"]:
diff --git a/tests/test_integration_workers.py b/tests/test_integration_workers.py
index 33c26cedf8..123b1ddc6f 100644
--- a/tests/test_integration_workers.py
+++ b/tests/test_integration_workers.py
@@ -44,6 +44,7 @@ def run_loading_test(num_workers=50, device=None, pw=False):
@skip_if_no_cuda
@SkipIfBeforePyTorchVersion((1, 9))
class IntegrationLoading(DistTestCase):
+
def tearDown(self):
set_determinism(seed=None)
diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py
index 7c6f35f3d3..fafb66f675 100644
--- a/tests/test_integration_workflows.py
+++ b/tests/test_integration_workflows.py
@@ -118,6 +118,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):
)
class _TestEvalIterEvents:
+
def attach(self, engine):
engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)
@@ -160,6 +161,7 @@ def _forward_completed(self, engine):
)
class _TestTrainIterEvents:
+
def attach(self, engine):
engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)
engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed)
@@ -284,6 +286,7 @@ def save_func(engine):
@skip_if_quick
class IntegrationWorkflows(DistTestCase):
+
def setUp(self):
set_determinism(seed=0)
diff --git a/tests/test_integration_workflows_gan.py b/tests/test_integration_workflows_gan.py
index 6896241d35..1428506020 100644
--- a/tests/test_integration_workflows_gan.py
+++ b/tests/test_integration_workflows_gan.py
@@ -127,6 +127,7 @@ def generator_loss(gen_images):
@skip_if_quick
class IntegrationWorkflowsGAN(DistTestCase):
+
def setUp(self):
set_determinism(seed=0)
diff --git a/tests/test_intensity_stats.py b/tests/test_intensity_stats.py
index 243fcd0dd4..e45c2acbad 100644
--- a/tests/test_intensity_stats.py
+++ b/tests/test_intensity_stats.py
@@ -53,6 +53,7 @@
class TestIntensityStats(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, input_param, img, meta_dict, expected):
_, meta_dict = IntensityStats(**input_param)(img, meta_dict)
diff --git a/tests/test_intensity_statsd.py b/tests/test_intensity_statsd.py
index 3fe82b1df7..d164f249db 100644
--- a/tests/test_intensity_statsd.py
+++ b/tests/test_intensity_statsd.py
@@ -52,6 +52,7 @@
class TestIntensityStatsd(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_value(self, input_param, data, meta_key, expected):
meta = IntensityStatsd(**input_param)(data)[meta_key]
diff --git a/tests/test_inverse_array.py b/tests/test_inverse_array.py
index c0b1a77e55..4da9ee34b9 100644
--- a/tests/test_inverse_array.py
+++ b/tests/test_inverse_array.py
@@ -33,6 +33,7 @@
@unittest.skipUnless(has_nib, "Requires nibabel")
class TestInverseArray(unittest.TestCase):
+
@staticmethod
def get_image(dtype, device) -> MetaTensor:
affine = torch.tensor([[0, 0, 1, 0], [-1, 0, 0, 0], [0, 10, 0, 0], [0, 0, 0, 1]]).to(dtype).to(device)
diff --git a/tests/test_invert.py b/tests/test_invert.py
index 9c57b11331..69d31edfc8 100644
--- a/tests/test_invert.py
+++ b/tests/test_invert.py
@@ -41,6 +41,7 @@
class TestInvert(unittest.TestCase):
+
def test_invert(self):
set_determinism(seed=0)
im_fname = make_nifti_image(create_test_image_3d(101, 100, 107, noise_max=100)[1]) # label image, discrete
diff --git a/tests/test_invertd.py b/tests/test_invertd.py
index 2e6ee35981..c32a3af643 100644
--- a/tests/test_invertd.py
+++ b/tests/test_invertd.py
@@ -43,6 +43,7 @@
class TestInvertd(unittest.TestCase):
+
def test_invert(self):
set_determinism(seed=0)
im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))
diff --git a/tests/test_is_supported_format.py b/tests/test_is_supported_format.py
index 591772bb3a..fb488eb054 100644
--- a/tests/test_is_supported_format.py
+++ b/tests/test_is_supported_format.py
@@ -33,6 +33,7 @@
class TestIsSupportedFormat(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
def test_value(self, input_param, result):
self.assertEqual(is_supported_format(**input_param), result)
diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index 38be9ec30c..cfa711e4c0 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -24,6 +24,7 @@
class _Stream:
+
def __init__(self, data):
self.data = data
@@ -32,6 +33,7 @@ def __iter__(self):
class TestIterableDataset(unittest.TestCase):
+
def test_shape(self):
expected_shape = (128, 128, 128)
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
diff --git a/tests/test_itk_torch_bridge.py b/tests/test_itk_torch_bridge.py
index b368230c53..22ae019271 100644
--- a/tests/test_itk_torch_bridge.py
+++ b/tests/test_itk_torch_bridge.py
@@ -49,6 +49,7 @@
@unittest.skipUnless(has_itk, "Requires `itk` package.")
class TestITKTorchAffineMatrixBridge(unittest.TestCase):
+
def setUp(self):
set_determinism(seed=0)
self.data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
@@ -493,6 +494,7 @@ def test_use_reference_space(self, ref_filepath, filepath):
@unittest.skipUnless(has_nib, "Requires `nibabel` package.")
@skip_if_quick
class TestITKTorchRW(unittest.TestCase):
+
def setUp(self):
TestITKTorchAffineMatrixBridge.setUp(self)
diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py
index c9707b1b5a..6625339dd0 100644
--- a/tests/test_itk_writer.py
+++ b/tests/test_itk_writer.py
@@ -27,6 +27,7 @@
@unittest.skipUnless(has_itk, "Requires `itk` package.")
class TestITKWriter(unittest.TestCase):
+
def test_channel_shape(self):
with tempfile.TemporaryDirectory() as tempdir:
for c in (0, 1, 2, 3):
diff --git a/tests/test_k_space_spike_noise.py b/tests/test_k_space_spike_noise.py
index 4d820573a6..17acedf319 100644
--- a/tests/test_k_space_spike_noise.py
+++ b/tests/test_k_space_spike_noise.py
@@ -32,6 +32,7 @@
class TestKSpaceSpikeNoise(unittest.TestCase):
+
def setUp(self):
set_determinism(0)
super().setUp()
diff --git a/tests/test_k_space_spike_noised.py b/tests/test_k_space_spike_noised.py
index 76a79d4b12..ce542af0aa 100644
--- a/tests/test_k_space_spike_noised.py
+++ b/tests/test_k_space_spike_noised.py
@@ -33,6 +33,7 @@
class TestKSpaceSpikeNoised(unittest.TestCase):
+
def setUp(self):
set_determinism(0)
super().setUp()
diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py
index 7da3c4b21f..2dfac1142e 100644
--- a/tests/test_keep_largest_connected_component.py
+++ b/tests/test_keep_largest_connected_component.py
@@ -381,6 +381,7 @@ def to_onehot(x):
class TestKeepLargestConnectedComponent(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_correct_results(self, _, args, input_image, expected):
converter = KeepLargestConnectedComponent(**args)
diff --git a/tests/test_keep_largest_connected_componentd.py b/tests/test_keep_largest_connected_componentd.py
index aac91a2de9..4d3172741d 100644
--- a/tests/test_keep_largest_connected_componentd.py
+++ b/tests/test_keep_largest_connected_componentd.py
@@ -337,6 +337,7 @@
class TestKeepLargestConnectedComponentd(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_results(self, _, args, input_dict, expected):
converter = KeepLargestConnectedComponentd(**args)
diff --git a/tests/test_kspace_mask.py b/tests/test_kspace_mask.py
index 5d6d9c18ea..cfbd7864c8 100644
--- a/tests/test_kspace_mask.py
+++ b/tests/test_kspace_mask.py
@@ -26,6 +26,7 @@
class TestMRIUtils(unittest.TestCase):
+
@parameterized.expand(TESTSM)
def test_mask(self, test_data):
# random mask
diff --git a/tests/test_label_filter.py b/tests/test_label_filter.py
index 47a8706491..93cf95a2a0 100644
--- a/tests/test_label_filter.py
+++ b/tests/test_label_filter.py
@@ -58,6 +58,7 @@
class TestLabelFilter(unittest.TestCase):
+
@parameterized.expand(VALID_TESTS)
def test_correct_results(self, _, args, input_image, expected):
converter = LabelFilter(**args)
diff --git a/tests/test_label_filterd.py b/tests/test_label_filterd.py
index f27df08c2a..fba8100f25 100644
--- a/tests/test_label_filterd.py
+++ b/tests/test_label_filterd.py
@@ -58,6 +58,7 @@
class TestLabelFilter(unittest.TestCase):
+
@parameterized.expand(VALID_TESTS)
def test_correct_results(self, _, args, input_image, expected):
converter = LabelFilterd(keys="image", **args)
diff --git a/tests/test_label_quality_score.py b/tests/test_label_quality_score.py
index aa243b4236..a46b78b1d4 100644
--- a/tests/test_label_quality_score.py
+++ b/tests/test_label_quality_score.py
@@ -99,6 +99,7 @@
class TestLabelQualityScore(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_value(self, input_data, expected_value):
result = label_quality_score(**input_data)
diff --git a/tests/test_label_to_contour.py b/tests/test_label_to_contour.py
index 590fd5d4e4..d7fbfc9b8d 100644
--- a/tests/test_label_to_contour.py
+++ b/tests/test_label_to_contour.py
@@ -142,6 +142,7 @@ def gen_fixed_img(array_type):
class TestContour(unittest.TestCase):
+
def test_contour(self):
input_param = {"kernel_type": "Laplace"}
diff --git a/tests/test_label_to_contourd.py b/tests/test_label_to_contourd.py
index 6fcec72dd8..a91a712da6 100644
--- a/tests/test_label_to_contourd.py
+++ b/tests/test_label_to_contourd.py
@@ -143,6 +143,7 @@ def gen_fixed_img(array_type):
class TestContourd(unittest.TestCase):
+
def test_contour(self):
input_param = {"keys": "img", "kernel_type": "Laplace"}
diff --git a/tests/test_label_to_mask.py b/tests/test_label_to_mask.py
index 2eba825cf3..47a58cc989 100644
--- a/tests/test_label_to_mask.py
+++ b/tests/test_label_to_mask.py
@@ -59,6 +59,7 @@
class TestLabelToMask(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
result = LabelToMask(**arguments)(image)
diff --git a/tests/test_label_to_maskd.py b/tests/test_label_to_maskd.py
index 35f54ca5b9..44b537128d 100644
--- a/tests/test_label_to_maskd.py
+++ b/tests/test_label_to_maskd.py
@@ -59,6 +59,7 @@
class TestLabelToMaskd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, input_data, expected_data):
result = LabelToMaskd(**arguments)(input_data)
diff --git a/tests/test_lambda.py b/tests/test_lambda.py
index e2276d671c..e0a5cf84db 100644
--- a/tests/test_lambda.py
+++ b/tests/test_lambda.py
@@ -23,6 +23,7 @@
class TestLambda(NumpyImageTestCase2D):
+
def test_lambda_identity(self):
for p in TEST_NDARRAYS:
img = p(self.imt)
diff --git a/tests/test_lambdad.py b/tests/test_lambdad.py
index 02e4423b74..fad5ebeee4 100644
--- a/tests/test_lambdad.py
+++ b/tests/test_lambdad.py
@@ -23,6 +23,7 @@
class TestLambdad(NumpyImageTestCase2D):
+
def test_lambdad_identity(self):
for p in TEST_NDARRAYS:
img = p(self.imt)
diff --git a/tests/test_lesion_froc.py b/tests/test_lesion_froc.py
index 10682c2bb7..0622809102 100644
--- a/tests/test_lesion_froc.py
+++ b/tests/test_lesion_froc.py
@@ -298,6 +298,7 @@ def prepare_test_data():
class TestEvaluateTumorFROC(unittest.TestCase):
+
@skipUnless(has_cucim, "Requires cucim")
@skipUnless(has_skimage, "Requires skimage")
@skipUnless(has_sp, "Requires scipy")
diff --git a/tests/test_list_data_collate.py b/tests/test_list_data_collate.py
index 9be61e3999..56ee040758 100644
--- a/tests/test_list_data_collate.py
+++ b/tests/test_list_data_collate.py
@@ -37,6 +37,7 @@
class TestListDataCollate(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_type_shape(self, input_data, expected_type, expected_shape):
result = list_data_collate(input_data)
diff --git a/tests/test_list_to_dict.py b/tests/test_list_to_dict.py
index 4e6bb8cdf7..abb61ea182 100644
--- a/tests/test_list_to_dict.py
+++ b/tests/test_list_to_dict.py
@@ -32,6 +32,7 @@
class TestListToDict(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_value_shape(self, input, output):
result = list_to_dict(input)
diff --git a/tests/test_lltm.py b/tests/test_lltm.py
index 6ee716e1ef..cc64672e77 100644
--- a/tests/test_lltm.py
+++ b/tests/test_lltm.py
@@ -29,6 +29,7 @@
class TestLLTM(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1])
@SkipIfNoModule("monai._C")
def test_value(self, input_param, expected_h, expected_c):
diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py
index 155b4eb0fc..9d128dd728 100644
--- a/tests/test_lmdbdataset.py
+++ b/tests/test_lmdbdataset.py
@@ -81,6 +81,7 @@
class _InplaceXform(Transform):
+
def __call__(self, data):
if data:
data[0] = data[0] + np.pi
@@ -91,6 +92,7 @@ def __call__(self, data):
@skip_if_windows
class TestLMDBDataset(unittest.TestCase):
+
def test_cache(self):
"""testing no inplace change to the hashed item"""
items = [[list(range(i))] for i in range(5)]
diff --git a/tests/test_lmdbdataset_dist.py b/tests/test_lmdbdataset_dist.py
index 0b4c7c35fa..1acb89beb3 100644
--- a/tests/test_lmdbdataset_dist.py
+++ b/tests/test_lmdbdataset_dist.py
@@ -23,6 +23,7 @@
class _InplaceXform(Transform):
+
def __call__(self, data):
if data:
data[0] = data[0] + np.pi
@@ -33,6 +34,7 @@ def __call__(self, data):
@skip_if_windows
class TestMPLMDBDataset(DistTestCase):
+
def setUp(self):
self.tempdir = tempfile.mkdtemp()
diff --git a/tests/test_load_decathlon_datalist.py b/tests/test_load_decathlon_datalist.py
index b0e390cd73..7281034498 100644
--- a/tests/test_load_decathlon_datalist.py
+++ b/tests/test_load_decathlon_datalist.py
@@ -21,6 +21,7 @@
class TestLoadDecathlonDatalist(unittest.TestCase):
+
def test_seg_values(self):
with tempfile.TemporaryDirectory() as tempdir:
test_data = {
diff --git a/tests/test_load_image.py b/tests/test_load_image.py
index b6a10bceb4..0207079d7d 100644
--- a/tests/test_load_image.py
+++ b/tests/test_load_image.py
@@ -160,6 +160,7 @@ def get_data(self, _obj):
@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadImage(unittest.TestCase):
+
@classmethod
def setUpClass(cls):
super(__class__, cls).setUpClass()
@@ -379,6 +380,7 @@ def test_channel_dim(self, input_param, filename, expected_shape):
@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadImageMeta(unittest.TestCase):
+
@classmethod
def setUpClass(cls):
super(__class__, cls).setUpClass()
diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py
index 534cbb6618..699ed70059 100644
--- a/tests/test_load_imaged.py
+++ b/tests/test_load_imaged.py
@@ -46,6 +46,7 @@
@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadImaged(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape(self, input_param, expected_shape):
test_image = nib.Nifti1Image(np.random.rand(128, 128, 128), np.eye(4))
@@ -94,6 +95,7 @@ def test_no_file(self):
@unittest.skipUnless(has_itk, "itk not installed")
class TestConsistency(unittest.TestCase):
+
def _cmp(self, filename, ch_shape, reader_1, reader_2, outname, ext):
data_dict = {"img": filename}
keys = data_dict.keys()
@@ -155,6 +157,7 @@ def test_png(self):
@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadImagedMeta(unittest.TestCase):
+
@classmethod
def setUpClass(cls):
super(__class__, cls).setUpClass()
diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py
index a121bdd3cd..63422761ca 100644
--- a/tests/test_load_spacing_orientation.py
+++ b/tests/test_load_spacing_orientation.py
@@ -30,6 +30,7 @@
class TestLoadSpacingOrientation(unittest.TestCase):
+
@staticmethod
def load_image(filename):
data = {"image": filename}
diff --git a/tests/test_loader_semaphore.py b/tests/test_loader_semaphore.py
index 859ee1f8d5..83557d830d 100644
--- a/tests/test_loader_semaphore.py
+++ b/tests/test_loader_semaphore.py
@@ -39,6 +39,7 @@ def _run_test():
class TestImportLock(unittest.TestCase):
+
def test_start(self):
_run_test()
diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py
index 21fe8b973f..35a24cd0ca 100644
--- a/tests/test_local_normalized_cross_correlation_loss.py
+++ b/tests/test_local_normalized_cross_correlation_loss.py
@@ -117,6 +117,7 @@
class TestLocalNormalizedCrossCorrelationLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data)
diff --git a/tests/test_localnet.py b/tests/test_localnet.py
index f557147960..97aa94d2c5 100644
--- a/tests/test_localnet.py
+++ b/tests/test_localnet.py
@@ -62,6 +62,7 @@
class TestLocalNet(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_LOCALNET_2D + TEST_CASE_LOCALNET_3D)
def test_shape(self, input_param, input_shape, expected_shape):
net = LocalNet(**input_param).to(device)
diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py
index 27ea4cd1a6..340a8e94ba 100644
--- a/tests/test_localnet_block.py
+++ b/tests/test_localnet_block.py
@@ -48,6 +48,7 @@
class TestLocalNetDownSampleBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_DOWN_SAMPLE)
def test_shape(self, input_param):
net = LocalNetDownSampleBlock(**input_param)
@@ -74,6 +75,7 @@ def test_ill_shape(self, input_param):
class TestLocalNetUpSampleBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_UP_SAMPLE)
def test_shape(self, input_param):
net = LocalNetUpSampleBlock(**input_param)
@@ -100,6 +102,7 @@ def test_ill_shape(self, input_param):
class TestExtractBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_EXTRACT)
def test_shape(self, input_param):
net = LocalNetFeatureExtractorBlock(**input_param)
diff --git a/tests/test_look_up_option.py b/tests/test_look_up_option.py
index 5f81fb8d43..d40b7eaa8c 100644
--- a/tests/test_look_up_option.py
+++ b/tests/test_look_up_option.py
@@ -44,6 +44,7 @@ class _CaseStrEnum(StrEnum):
class TestLookUpOption(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_look_up(self, input_str, supported, expected):
output = look_up_option(input_str, supported)
diff --git a/tests/test_loss_metric.py b/tests/test_loss_metric.py
index 682221f5f5..365dc10670 100644
--- a/tests/test_loss_metric.py
+++ b/tests/test_loss_metric.py
@@ -36,6 +36,7 @@
class TestComputeLossMetric(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1])
def test_value_class(self, input_data, expected_value):
loss_fn = input_data["loss_class"](**input_data["loss_kwargs"])
diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py
index 46375890eb..d26cb23a90 100644
--- a/tests/test_lr_finder.py
+++ b/tests/test_lr_finder.py
@@ -48,6 +48,7 @@
@unittest.skipUnless(sys.platform == "linux", "requires linux")
@unittest.skipUnless(has_pil, "requires PIL")
class TestLRFinder(unittest.TestCase):
+
def setUp(self):
self.root_dir = MONAIEnvVars.data_dir()
if not self.root_dir:
diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py
index 54092ba931..1a61796fe0 100644
--- a/tests/test_lr_scheduler.py
+++ b/tests/test_lr_scheduler.py
@@ -20,6 +20,7 @@
class SchedulerTestNet(torch.nn.Module):
+
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 1, 1)
@@ -43,6 +44,7 @@ def forward(self, x):
class TestLRSCHEDULER(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_LRSCHEDULER)
def test_shape(self, input_param, expected_lr):
net = SchedulerTestNet()
diff --git a/tests/test_make_nifti.py b/tests/test_make_nifti.py
index 4560507c6c..08d3a731ab 100644
--- a/tests/test_make_nifti.py
+++ b/tests/test_make_nifti.py
@@ -34,6 +34,7 @@
@unittest.skipUnless(has_nib, "Requires nibabel")
class TestMakeNifti(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_make_nifti(self, params):
im, _ = create_test_image_2d(100, 88)
diff --git a/tests/test_map_binary_to_indices.py b/tests/test_map_binary_to_indices.py
index 1080c2a513..9931d997bb 100644
--- a/tests/test_map_binary_to_indices.py
+++ b/tests/test_map_binary_to_indices.py
@@ -64,6 +64,7 @@
class TestMapBinaryToIndices(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_type_shape(self, input_data, expected_fg, expected_bg):
fg_indices, bg_indices = map_binary_to_indices(**input_data)
diff --git a/tests/test_map_classes_to_indices.py b/tests/test_map_classes_to_indices.py
index 9c8b4b4793..902744ab65 100644
--- a/tests/test_map_classes_to_indices.py
+++ b/tests/test_map_classes_to_indices.py
@@ -124,6 +124,7 @@
class TestMapClassesToIndices(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, input_data, expected_indices):
indices = map_classes_to_indices(**input_data)
diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py
index 6b8121b6df..cd311df6bd 100644
--- a/tests/test_map_label_value.py
+++ b/tests/test_map_label_value.py
@@ -75,6 +75,7 @@
class TestMapLabelValue(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, input_param, input_data, expected_value):
result = MapLabelValue(**input_param)(input_data)
diff --git a/tests/test_map_label_valued.py b/tests/test_map_label_valued.py
index fa0d094393..0fb46f2515 100644
--- a/tests/test_map_label_valued.py
+++ b/tests/test_map_label_valued.py
@@ -69,6 +69,7 @@
class TestMapLabelValued(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_5_1, TEST_CASE_6, TEST_CASE_7]
)
diff --git a/tests/test_map_transform.py b/tests/test_map_transform.py
index 7430cf09c7..a7be7b9f5d 100644
--- a/tests/test_map_transform.py
+++ b/tests/test_map_transform.py
@@ -23,11 +23,13 @@
class MapTest(MapTransform):
+
def __call__(self, data):
pass
class TestRandomizable(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_keys(self, keys, expected):
transform = MapTest(keys=keys)
diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py
index 2b831ba415..b7ff324946 100644
--- a/tests/test_mask_intensity.py
+++ b/tests/test_mask_intensity.py
@@ -55,6 +55,7 @@
class TestMaskIntensity(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_value(self, arguments, image, expected_data):
for p in TEST_NDARRAYS:
diff --git a/tests/test_mask_intensityd.py b/tests/test_mask_intensityd.py
index 6a39416de4..0efd1f835f 100644
--- a/tests/test_mask_intensityd.py
+++ b/tests/test_mask_intensityd.py
@@ -57,6 +57,7 @@
class TestMaskIntensityd(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_value(self, arguments, image, expected_data):
result = MaskIntensityd(**arguments)(image)
diff --git a/tests/test_masked_dice_loss.py b/tests/test_masked_dice_loss.py
index b868f4d3a1..c971723615 100644
--- a/tests/test_masked_dice_loss.py
+++ b/tests/test_masked_dice_loss.py
@@ -113,6 +113,7 @@
class TestDiceLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = MaskedDiceLoss(**input_param).forward(**input_data)
diff --git a/tests/test_masked_loss.py b/tests/test_masked_loss.py
index 708d507523..3c04ffadcb 100644
--- a/tests/test_masked_loss.py
+++ b/tests/test_masked_loss.py
@@ -40,6 +40,7 @@
class TestMaskedLoss(unittest.TestCase):
+
def setUp(self):
set_determinism(0)
diff --git a/tests/test_masked_patch_wsi_dataset.py b/tests/test_masked_patch_wsi_dataset.py
index 35509b32f6..8d24075595 100644
--- a/tests/test_masked_patch_wsi_dataset.py
+++ b/tests/test_masked_patch_wsi_dataset.py
@@ -74,6 +74,7 @@ def setUpModule():
class MaskedPatchWSIDatasetTests:
+
class Tests(unittest.TestCase):
backend = None
@@ -100,6 +101,7 @@ def test_gen_patches(self, input_parameters, expected):
@skipUnless(has_cucim, "Requires cucim")
class TestSlidingPatchWSIDatasetCuCIM(MaskedPatchWSIDatasetTests.Tests):
+
@classmethod
def setUpClass(cls):
cls.backend = "cucim"
@@ -107,6 +109,7 @@ def setUpClass(cls):
@skipUnless(has_osl, "Requires openslide")
class TestSlidingPatchWSIDatasetOpenSlide(MaskedPatchWSIDatasetTests.Tests):
+
@classmethod
def setUpClass(cls):
cls.backend = "openslide"
diff --git a/tests/test_matshow3d.py b/tests/test_matshow3d.py
index a6cb3fcee3..e513025e69 100644
--- a/tests/test_matshow3d.py
+++ b/tests/test_matshow3d.py
@@ -35,6 +35,7 @@
@SkipIfNoModule("matplotlib")
class TestMatshow3d(unittest.TestCase):
+
def test_3d(self):
testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
keys = "image"
diff --git a/tests/test_mean_ensemble.py b/tests/test_mean_ensemble.py
index 09b7f94dc4..6b463f8530 100644
--- a/tests/test_mean_ensemble.py
+++ b/tests/test_mean_ensemble.py
@@ -58,6 +58,7 @@
class TestMeanEnsemble(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, input_param, img, expected_value):
result = MeanEnsemble(**input_param)(img)
diff --git a/tests/test_mean_ensembled.py b/tests/test_mean_ensembled.py
index 01123b0729..795ae47368 100644
--- a/tests/test_mean_ensembled.py
+++ b/tests/test_mean_ensembled.py
@@ -72,6 +72,7 @@
class TestMeanEnsembled(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, input_param, data, expected_value):
result = MeanEnsembled(**input_param)(data)
diff --git a/tests/test_median_filter.py b/tests/test_median_filter.py
index 9f27adff4c..1f5e623260 100644
--- a/tests/test_median_filter.py
+++ b/tests/test_median_filter.py
@@ -20,6 +20,7 @@
class MedianFilterTestCase(unittest.TestCase):
+
def test_3d_big(self):
a = torch.ones(1, 1, 2, 3, 5)
g = MedianFilter([1, 2, 4]).to(torch.device("cpu:0"))
diff --git a/tests/test_median_smooth.py b/tests/test_median_smooth.py
index 21cd45f28e..5930c0c6b6 100644
--- a/tests/test_median_smooth.py
+++ b/tests/test_median_smooth.py
@@ -31,6 +31,7 @@
class TestMedianSmooth(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
result = MedianSmooth(**arguments)(image)
diff --git a/tests/test_median_smoothd.py b/tests/test_median_smoothd.py
index b8d3452c86..e0bdb331c8 100644
--- a/tests/test_median_smoothd.py
+++ b/tests/test_median_smoothd.py
@@ -55,6 +55,7 @@
class TestMedianSmoothd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
result = MedianSmoothd(**arguments)(image)
diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py
index b5a809ccaa..baf3bf4f2d 100644
--- a/tests/test_mednistdataset.py
+++ b/tests/test_mednistdataset.py
@@ -25,6 +25,7 @@
class TestMedNISTDataset(unittest.TestCase):
+
@skip_if_quick
def test_values(self):
testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
diff --git a/tests/test_meta_affine.py b/tests/test_meta_affine.py
index b95ea3f1ac..95764a0c89 100644
--- a/tests/test_meta_affine.py
+++ b/tests/test_meta_affine.py
@@ -123,6 +123,7 @@ def _resample_to_affine(itk_obj, ref_obj):
@unittest.skipUnless(has_itk, "Requires itk package.")
class TestAffineConsistencyITK(unittest.TestCase):
+
@classmethod
def setUpClass(cls):
super().setUpClass()
diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py
index 0cd0522036..1e0f188b63 100644
--- a/tests/test_meta_tensor.py
+++ b/tests/test_meta_tensor.py
@@ -50,6 +50,7 @@ def rand_string(min_len=5, max_len=10):
class TestMetaTensor(unittest.TestCase):
+
@staticmethod
def get_im(shape=None, dtype=None, device=None):
if shape is None:
diff --git a/tests/test_metatensor_integration.py b/tests/test_metatensor_integration.py
index 6a4c67d160..d647e47e74 100644
--- a/tests/test_metatensor_integration.py
+++ b/tests/test_metatensor_integration.py
@@ -39,6 +39,7 @@
@unittest.skipUnless(has_nib, "Requires nibabel package.")
class TestMetaTensorIntegration(unittest.TestCase):
+
@classmethod
def setUpClass(cls):
super().setUpClass()
diff --git a/tests/test_metrics_reloaded.py b/tests/test_metrics_reloaded.py
index 010326b87d..562693c07c 100644
--- a/tests/test_metrics_reloaded.py
+++ b/tests/test_metrics_reloaded.py
@@ -76,6 +76,7 @@
@unittest.skipIf(not has_metrics, "MetricsReloaded not available.")
class TestMetricsReloaded(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_BINARY)
def test_binary(self, input_param, input_data, expected_val):
metric = MetricsReloadedBinary(**input_param)
diff --git a/tests/test_milmodel.py b/tests/test_milmodel.py
index 9178e0bccb..42116e8220 100644
--- a/tests/test_milmodel.py
+++ b/tests/test_milmodel.py
@@ -63,6 +63,7 @@
class TestMilModel(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_MILMODEL)
def test_shape(self, input_param, input_shape, expected_shape):
with skip_if_downloading_fails():
diff --git a/tests/test_mlp.py b/tests/test_mlp.py
index 8ad66ebc6e..54f70d3318 100644
--- a/tests/test_mlp.py
+++ b/tests/test_mlp.py
@@ -33,6 +33,7 @@
class TestMLPBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_MLP)
def test_shape(self, input_param, input_shape, expected_shape):
net = MLPBlock(**input_param)
diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py
index 66fca6bb7f..6af3d09fb2 100644
--- a/tests/test_mmar_download.py
+++ b/tests/test_mmar_download.py
@@ -116,6 +116,7 @@
@unittest.skip("deprecating mmar tests")
class TestMMMARDownload(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
@skip_if_quick
def test_download(self, idx):
diff --git a/tests/test_module_list.py b/tests/test_module_list.py
index 293da95d5a..d21ba53b7c 100644
--- a/tests/test_module_list.py
+++ b/tests/test_module_list.py
@@ -21,6 +21,7 @@
class TestAllImport(unittest.TestCase):
+
def test_public_api(self):
"""
This is to check "monai.__all__" should be consistent with
diff --git a/tests/test_monai_env_vars.py b/tests/test_monai_env_vars.py
index 6e0d6f0ddf..f5ef28a0ac 100644
--- a/tests/test_monai_env_vars.py
+++ b/tests/test_monai_env_vars.py
@@ -18,6 +18,7 @@
class TestMONAIEnvVars(unittest.TestCase):
+
@classmethod
def setUpClass(cls):
super(__class__, cls).setUpClass()
diff --git a/tests/test_monai_utils_misc.py b/tests/test_monai_utils_misc.py
index 742c9e4047..a2a4ed62f7 100644
--- a/tests/test_monai_utils_misc.py
+++ b/tests/test_monai_utils_misc.py
@@ -40,11 +40,13 @@
class MiscClass:
+
def __init__(self, arg1, arg2, kwargs1=None, kwargs2=None):
pass
class TestToTupleOfDictionaries(unittest.TestCase):
+
@parameterized.expand(TO_TUPLE_OF_DICTIONARIES_TEST_CASES)
def test_to_tuple_of_dictionaries(self, dictionary, keys, expected):
self._test_to_tuple_of_dictionaries(dictionary, keys, expected)
@@ -61,6 +63,7 @@ def _test_to_tuple_of_dictionaries(self, dictionary, keys, expected):
class TestMiscKwargs(unittest.TestCase):
+
def test_kwargs(self):
present, extra_args = self._custom_user_function(MiscClass, 1, kwargs1="value1", kwargs2="value2")
self.assertEqual(present, True)
@@ -74,6 +77,7 @@ def _custom_user_function(self, cls, *args, **kwargs):
class TestCommandRunner(unittest.TestCase):
+
def setUp(self):
self.orig_flag = str(MONAIEnvVars.debug())
diff --git a/tests/test_mri_utils.py b/tests/test_mri_utils.py
index 2f67816e2e..aabf06d02e 100644
--- a/tests/test_mri_utils.py
+++ b/tests/test_mri_utils.py
@@ -27,6 +27,7 @@
class TestMRIUtils(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_rss(self, test_data, res_data):
result = root_sum_of_squares(test_data, spatial_dim=1)
diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py
index 8b8acb2503..6681f266a8 100644
--- a/tests/test_multi_scale.py
+++ b/tests/test_multi_scale.py
@@ -52,6 +52,7 @@
class TestMultiScale(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = MultiScaleLoss(**input_param).forward(**input_data)
diff --git a/tests/test_net_adapter.py b/tests/test_net_adapter.py
index 74a2daab9d..242326e242 100644
--- a/tests/test_net_adapter.py
+++ b/tests/test_net_adapter.py
@@ -42,6 +42,7 @@
class TestNetAdapter(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_shape(self, input_param, input_shape, expected_shape):
spatial_dims = input_param["dim"]
diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py
index aca145a03d..4182501808 100644
--- a/tests/test_network_consistency.py
+++ b/tests/test_network_consistency.py
@@ -38,6 +38,7 @@
class TestNetworkConsistency(unittest.TestCase):
+
def setUp(self):
set_determinism(0)
diff --git a/tests/test_nifti_endianness.py b/tests/test_nifti_endianness.py
index 2539d95fd5..4475d8aaab 100644
--- a/tests/test_nifti_endianness.py
+++ b/tests/test_nifti_endianness.py
@@ -46,6 +46,7 @@
class TestNiftiEndianness(unittest.TestCase):
+
def setUp(self):
self.im, _ = create_test_image_2d(100, 100)
self.fname = tempfile.NamedTemporaryFile(suffix=".nii.gz").name
diff --git a/tests/test_nifti_header_revise.py b/tests/test_nifti_header_revise.py
index 3d000160e1..411c783fb5 100644
--- a/tests/test_nifti_header_revise.py
+++ b/tests/test_nifti_header_revise.py
@@ -20,6 +20,7 @@
class TestRectifyHeaderSformQform(unittest.TestCase):
+
def test_revise_q(self):
img = nib.Nifti1Image(np.zeros((10, 10, 10)), np.eye(4))
img.header.set_zooms((0.1, 0.2, 0.3))
diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py
index f45c2ac5a7..8543fcea30 100644
--- a/tests/test_nifti_rw.py
+++ b/tests/test_nifti_rw.py
@@ -72,6 +72,7 @@
class TestNiftiLoadRead(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_orientation(self, array, affine, reader_param, expected):
test_image = make_nifti_image(array, affine)
diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py
index 193b5cc4b2..72ebf579e1 100644
--- a/tests/test_normalize_intensity.py
+++ b/tests/test_normalize_intensity.py
@@ -83,6 +83,7 @@
class TestNormalizeIntensity(NumpyImageTestCase2D):
+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_default(self, im_type):
im = im_type(self.imt.copy())
diff --git a/tests/test_normalize_intensityd.py b/tests/test_normalize_intensityd.py
index 451269b1c4..229dcd00ff 100644
--- a/tests/test_normalize_intensityd.py
+++ b/tests/test_normalize_intensityd.py
@@ -51,6 +51,7 @@
class TestNormalizeIntensityd(NumpyImageTestCase2D):
+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_image_normalize_intensityd(self, im_type):
key = "img"
diff --git a/tests/test_npzdictitemdataset.py b/tests/test_npzdictitemdataset.py
index 4ff4577b72..e2196f1907 100644
--- a/tests/test_npzdictitemdataset.py
+++ b/tests/test_npzdictitemdataset.py
@@ -21,6 +21,7 @@
class TestNPZDictItemDataset(unittest.TestCase):
+
def test_load_stream(self):
dat0 = np.random.rand(10, 1, 4, 4)
dat1 = np.random.rand(10, 1, 4, 4)
diff --git a/tests/test_nrrd_reader.py b/tests/test_nrrd_reader.py
index 01fabe65a8..649b9fa94d 100644
--- a/tests/test_nrrd_reader.py
+++ b/tests/test_nrrd_reader.py
@@ -48,6 +48,7 @@
@skipUnless(has_nrrd, "nrrd required")
class TestNrrdReader(unittest.TestCase):
+
def test_verify_suffix(self):
reader = NrrdReader()
self.assertFalse(reader.verify_suffix("test_image.nrd"))
diff --git a/tests/test_nuclick_transforms.py b/tests/test_nuclick_transforms.py
index fcdd362b01..a6e66c3658 100644
--- a/tests/test_nuclick_transforms.py
+++ b/tests/test_nuclick_transforms.py
@@ -179,6 +179,7 @@
class TestFilterImaged(unittest.TestCase):
+
@parameterized.expand([FILTER_IMAGE_TEST_CASE_1])
def test_correct_shape(self, arguments, input_data, expected_shape):
result = FilterImaged(**arguments)(input_data)
@@ -186,6 +187,7 @@ def test_correct_shape(self, arguments, input_data, expected_shape):
class TestFlattenLabeld(unittest.TestCase):
+
@parameterized.expand([FLATTEN_LABEL_TEST_CASE_1, FLATTEN_LABEL_TEST_CASE_2, FLATTEN_LABEL_TEST_CASE_3])
def test_correct_num_labels(self, arguments, input_data, expected_result):
result = FlattenLabeld(**arguments)(input_data)
@@ -193,6 +195,7 @@ def test_correct_num_labels(self, arguments, input_data, expected_result):
class TestExtractPatchd(unittest.TestCase):
+
@parameterized.expand([EXTRACT_TEST_CASE_1, EXTRACT_TEST_CASE_2, EXTRACT_TEST_CASE_3])
def test_correct_patch_size(self, arguments, input_data, expected_shape):
result = ExtractPatchd(**arguments)(input_data)
@@ -205,6 +208,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestSplitLabelsd(unittest.TestCase):
+
@parameterized.expand([SPLIT_TEST_CASE_1, SPLIT_TEST_CASE_2])
def test_correct_results(self, arguments, input_data, expected_result):
result = SplitLabeld(**arguments)(input_data)
@@ -212,6 +216,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestGuidanceSignal(unittest.TestCase):
+
@parameterized.expand([GUIDANCE_TEST_CASE_1, GUIDANCE_TEST_CASE_2])
def test_correct_shape(self, arguments, input_data, expected_shape):
result = AddPointGuidanceSignald(**arguments)(input_data)
@@ -219,6 +224,7 @@ def test_correct_shape(self, arguments, input_data, expected_shape):
class TestClickSignal(unittest.TestCase):
+
@parameterized.expand([CLICK_TEST_CASE_1, CLICK_TEST_CASE_2])
def test_correct_shape(self, arguments, input_data, expected_shape):
result = AddClickSignalsd(**arguments)(input_data)
@@ -226,6 +232,7 @@ def test_correct_shape(self, arguments, input_data, expected_shape):
class TestPostFilterLabel(unittest.TestCase):
+
@parameterized.expand([LABEL_FILTER_TEST_CASE_1])
def test_correct_shape(self, arguments, input_data, expected_shape):
result = PostFilterLabeld(**arguments)(input_data)
@@ -233,6 +240,7 @@ def test_correct_shape(self, arguments, input_data, expected_shape):
class TestAddLabelAsGuidance(unittest.TestCase):
+
@parameterized.expand([LABEL_GUIDANCE_TEST_CASE_1])
def test_correct_shape(self, arguments, input_data, expected_shape):
result = AddLabelAsGuidanced(**arguments)(input_data)
@@ -240,6 +248,7 @@ def test_correct_shape(self, arguments, input_data, expected_shape):
class TestSetLabelClass(unittest.TestCase):
+
@parameterized.expand([LABEL_CLASS_TEST_CASE_1])
def test_correct_results(self, arguments, input_data, expected_result):
result = SetLabelClassd(**arguments)(input_data)
diff --git a/tests/test_numpy_reader.py b/tests/test_numpy_reader.py
index eeff2922ad..6303598bb7 100644
--- a/tests/test_numpy_reader.py
+++ b/tests/test_numpy_reader.py
@@ -24,6 +24,7 @@
class TestNumpyReader(unittest.TestCase):
+
def test_npy(self):
test_data = np.random.randint(0, 256, size=[3, 4, 4])
with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_nvtx_decorator.py b/tests/test_nvtx_decorator.py
index 574fd49592..efd2906972 100644
--- a/tests/test_nvtx_decorator.py
+++ b/tests/test_nvtx_decorator.py
@@ -72,6 +72,7 @@
@unittest.skipUnless(has_nvtx, "Required torch._C._nvtx for NVTX Range!")
class TestNVTXRangeDecorator(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1])
def test_tranform_array(self, input):
transforms = Compose([Range("random flip")(Flip()), Range()(ToTensor())])
diff --git a/tests/test_nvtx_transform.py b/tests/test_nvtx_transform.py
index 3a5314c35f..af15c53d1b 100644
--- a/tests/test_nvtx_transform.py
+++ b/tests/test_nvtx_transform.py
@@ -43,6 +43,7 @@
class TestNVTXTransforms(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1, TEST_CASE_DICT_0, TEST_CASE_DICT_1])
@unittest.skipUnless(has_nvtx, "CUDA is required for NVTX!")
def test_nvtx_transfroms_alone(self, input):
diff --git a/tests/test_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py
index c7ac5ef533..d821c9bcd9 100644
--- a/tests/test_occlusion_sensitivity.py
+++ b/tests/test_occlusion_sensitivity.py
@@ -22,6 +22,7 @@
class DenseNetAdjoint(DenseNet121):
+
def __call__(self, x, adjoint_info):
if adjoint_info != 42:
raise ValueError
@@ -104,6 +105,7 @@ def __call__(self, x, adjoint_info):
class TestComputeOcclusionSensitivity(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, init_data, call_data, map_expected_shape, most_prob_expected_shape):
occ_sens = OcclusionSensitivity(**init_data)
diff --git a/tests/test_one_of.py b/tests/test_one_of.py
index 2909597507..ecf1cb3319 100644
--- a/tests/test_one_of.py
+++ b/tests/test_one_of.py
@@ -39,31 +39,37 @@
class X(Transform):
+
def __call__(self, x):
return x
class Y(Transform):
+
def __call__(self, x):
return x
class A(Transform):
+
def __call__(self, x):
return x + 1
class B(Transform):
+
def __call__(self, x):
return x + 2
class C(Transform):
+
def __call__(self, x):
return x + 3
class MapBase(MapTransform):
+
def __init__(self, keys):
super().__init__(keys)
self.fwd_fn, self.inv_fn = None, None
@@ -76,12 +82,14 @@ def __call__(self, data):
class NonInv(MapBase):
+
def __init__(self, keys):
super().__init__(keys)
self.fwd_fn = lambda x: x * 2
class Inv(MapBase, InvertibleTransform):
+
def __call__(self, data):
d = deepcopy(dict(data))
for key in self.key_iterator(d):
@@ -98,6 +106,7 @@ def inverse(self, data):
class InvA(Inv):
+
def __init__(self, keys):
super().__init__(keys)
self.fwd_fn = lambda x: x + 1
@@ -105,6 +114,7 @@ def __init__(self, keys):
class InvB(Inv):
+
def __init__(self, keys):
super().__init__(keys)
self.fwd_fn = lambda x: x + 100
@@ -123,6 +133,7 @@ def __init__(self, keys):
class TestOneOf(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_normalize_weights(self, transforms, input_weights, expected_weights):
tr = OneOf(transforms, input_weights)
@@ -240,6 +251,7 @@ def test_one_of(self):
class TestOneOfAPITests(unittest.TestCase):
+
@staticmethod
def data_from_keys(keys):
if keys is None:
diff --git a/tests/test_optional_import.py b/tests/test_optional_import.py
index 03db7b3fc6..e7e1c03fd0 100644
--- a/tests/test_optional_import.py
+++ b/tests/test_optional_import.py
@@ -17,6 +17,7 @@
class TestOptionalImport(unittest.TestCase):
+
def test_default(self):
my_module, flag = optional_import("not_a_module")
self.assertFalse(flag)
diff --git a/tests/test_ori_ras_lps.py b/tests/test_ori_ras_lps.py
index 824793f927..39c0a57877 100644
--- a/tests/test_ori_ras_lps.py
+++ b/tests/test_ori_ras_lps.py
@@ -38,6 +38,7 @@
class TestITKWriter(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_AFFINE)
def test_ras_to_lps(self, param, expected):
assert_allclose(orientation_ras_lps(param), expected)
diff --git a/tests/test_orientation.py b/tests/test_orientation.py
index aa1c326bdf..2f3334e622 100644
--- a/tests/test_orientation.py
+++ b/tests/test_orientation.py
@@ -177,6 +177,7 @@
class TestOrientationCase(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_ornt_meta(
self,
diff --git a/tests/test_orientationd.py b/tests/test_orientationd.py
index cf4eb23d42..b885266c69 100644
--- a/tests/test_orientationd.py
+++ b/tests/test_orientationd.py
@@ -65,6 +65,7 @@
class TestOrientationdCase(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_orntd(
self, init_param, img: torch.Tensor, affine: torch.Tensor | None, expected_shape, expected_code, device
diff --git a/tests/test_p3d_block.py b/tests/test_p3d_block.py
index db9e9c284d..1a4ea6c884 100644
--- a/tests/test_p3d_block.py
+++ b/tests/test_p3d_block.py
@@ -62,6 +62,7 @@
class TestP3D(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_3D)
def test_3d(self, input_param, input_shape, expected_shape):
net = P3DActiConvNormBlock(**input_param)
diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py
index cd98f29abf..ee6e001438 100644
--- a/tests/test_pad_collation.py
+++ b/tests/test_pad_collation.py
@@ -60,6 +60,7 @@ def _testing_collate(x):
class _Dataset(torch.utils.data.Dataset):
+
def __init__(self, images, labels, transforms):
self.images = images
self.labels = labels
@@ -73,6 +74,7 @@ def __getitem__(self, index):
class TestPadCollation(unittest.TestCase):
+
def setUp(self) -> None:
set_determinism(seed=0)
# image is non square to throw rotation errors
diff --git a/tests/test_pad_mode.py b/tests/test_pad_mode.py
index 722d5b573f..54ee2c6d75 100644
--- a/tests/test_pad_mode.py
+++ b/tests/test_pad_mode.py
@@ -23,6 +23,7 @@
@SkipIfBeforePyTorchVersion((1, 10, 1))
class TestPadMode(unittest.TestCase):
+
def test_pad(self):
expected_shapes = {3: (1, 15, 10), 4: (1, 10, 6, 7)}
for t in (float, int, np.uint8, np.int16, np.float32, bool):
diff --git a/tests/test_partition_dataset.py b/tests/test_partition_dataset.py
index 8640d8cc73..c93a6c7682 100644
--- a/tests/test_partition_dataset.py
+++ b/tests/test_partition_dataset.py
@@ -118,6 +118,7 @@
class TestPartitionDataset(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]
)
diff --git a/tests/test_partition_dataset_classes.py b/tests/test_partition_dataset_classes.py
index c4fa5ed199..4c13b2f463 100644
--- a/tests/test_partition_dataset_classes.py
+++ b/tests/test_partition_dataset_classes.py
@@ -76,6 +76,7 @@
class TestPartitionDatasetClasses(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_value(self, input_param, result):
self.assertListEqual(partition_dataset_classes(**input_param), result)
diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py
index eb705f0c61..9a81d84363 100644
--- a/tests/test_patch_dataset.py
+++ b/tests/test_patch_dataset.py
@@ -27,6 +27,7 @@ def identity(x):
class TestPatchDataset(unittest.TestCase):
+
def test_shape(self):
test_dataset = ["vwxyz", "hello", "world"]
n_per_image = len(test_dataset[0])
diff --git a/tests/test_patch_inferer.py b/tests/test_patch_inferer.py
index 032d22bb98..c6308224b0 100644
--- a/tests/test_patch_inferer.py
+++ b/tests/test_patch_inferer.py
@@ -245,6 +245,7 @@
class PatchInfererTests(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_0_TENSOR,
diff --git a/tests/test_patch_wsi_dataset.py b/tests/test_patch_wsi_dataset.py
index cb9ebcf7e3..70e01eaaf4 100644
--- a/tests/test_patch_wsi_dataset.py
+++ b/tests/test_patch_wsi_dataset.py
@@ -128,6 +128,7 @@ def setUpModule():
class PatchWSIDatasetTests:
+
class Tests(unittest.TestCase):
backend = None
@@ -182,6 +183,7 @@ def test_read_patches_str_multi(self, input_parameters, expected):
@skipUnless(has_cim, "Requires cucim")
class TestPatchWSIDatasetCuCIM(PatchWSIDatasetTests.Tests):
+
@classmethod
def setUpClass(cls):
cls.backend = "cucim"
@@ -189,6 +191,7 @@ def setUpClass(cls):
@skipUnless(has_osl, "Requires openslide")
class TestPatchWSIDatasetOpenSlide(PatchWSIDatasetTests.Tests):
+
@classmethod
def setUpClass(cls):
cls.backend = "openslide"
diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py
index 77ade984eb..f8610d9214 100644
--- a/tests/test_patchembedding.py
+++ b/tests/test_patchembedding.py
@@ -77,6 +77,7 @@
@SkipIfBeforePyTorchVersion((1, 11, 1))
class TestPatchEmbeddingBlock(unittest.TestCase):
+
def setUp(self):
self.threads = torch.get_num_threads()
torch.set_num_threads(4)
@@ -162,6 +163,7 @@ def test_ill_arg(self):
class TestPatchEmbed(unittest.TestCase):
+
def setUp(self):
self.threads = torch.get_num_threads()
torch.set_num_threads(4)
diff --git a/tests/test_pathology_he_stain.py b/tests/test_pathology_he_stain.py
index 7ddad4ad6f..26941c6abb 100644
--- a/tests/test_pathology_he_stain.py
+++ b/tests/test_pathology_he_stain.py
@@ -73,6 +73,7 @@
class TestExtractHEStains(unittest.TestCase):
+
@parameterized.expand(
[NEGATIVE_VALUE_TEST_CASE, INVALID_VALUE_TEST_CASE, EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_1]
)
@@ -145,6 +146,7 @@ def test_result_value(self, image, expected_data):
class TestNormalizeHEStains(unittest.TestCase):
+
@parameterized.expand(
[NEGATIVE_VALUE_TEST_CASE, INVALID_VALUE_TEST_CASE, NORMALIZE_STAINS_TEST_CASE_0, NORMALIZE_STAINS_TEST_CASE_1]
)
diff --git a/tests/test_pathology_he_stain_dict.py b/tests/test_pathology_he_stain_dict.py
index 07db1c3e48..975dc4ffb8 100644
--- a/tests/test_pathology_he_stain_dict.py
+++ b/tests/test_pathology_he_stain_dict.py
@@ -67,6 +67,7 @@
class TestExtractHEStainsD(unittest.TestCase):
+
@parameterized.expand([EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_1])
def test_transparent_image(self, image):
"""
@@ -140,6 +141,7 @@ def test_result_value(self, image, expected_data):
class TestNormalizeHEStainsD(unittest.TestCase):
+
@parameterized.expand([NORMALIZE_STAINS_TEST_CASE_0, NORMALIZE_STAINS_TEST_CASE_1])
def test_transparent_image(self, image):
"""
diff --git a/tests/test_pathology_prob_nms.py b/tests/test_pathology_prob_nms.py
index 0053500437..b3d7da2c1d 100644
--- a/tests/test_pathology_prob_nms.py
+++ b/tests/test_pathology_prob_nms.py
@@ -43,6 +43,7 @@
class TestPathologyProbNMS(unittest.TestCase):
+
@parameterized.expand([TEST_CASES_2D, TEST_CASES_3D])
def test_output(self, class_args, call_args, probs_map, expected):
nms = PathologyProbNMS(**class_args)
diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py
index 7e4860e7f9..ba204af697 100644
--- a/tests/test_perceptual_loss.py
+++ b/tests/test_perceptual_loss.py
@@ -52,6 +52,7 @@
@unittest.skipUnless(has_torchvision, "Requires torchvision")
@skip_if_quick
class TestPerceptualLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_shape, target_shape):
with skip_if_downloading_fails():
diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py
index 1b8245e318..b7bf2fbb11 100644
--- a/tests/test_persistentdataset.py
+++ b/tests/test_persistentdataset.py
@@ -45,6 +45,7 @@
class _InplaceXform(Transform):
+
def __call__(self, data):
if data:
data[0] = data[0] + np.pi
@@ -54,6 +55,7 @@ def __call__(self, data):
class TestDataset(unittest.TestCase):
+
def test_cache(self):
"""testing no inplace change to the hashed item"""
items = [[list(range(i))] for i in range(5)]
diff --git a/tests/test_persistentdataset_dist.py b/tests/test_persistentdataset_dist.py
index e69c32b1eb..c369af9e92 100644
--- a/tests/test_persistentdataset_dist.py
+++ b/tests/test_persistentdataset_dist.py
@@ -25,6 +25,7 @@
class _InplaceXform(Transform):
+
def __call__(self, data):
if data:
data[0] = data[0] + np.pi
@@ -34,6 +35,7 @@ def __call__(self, data):
class TestDistDataset(DistTestCase):
+
def setUp(self):
self.tempdir = tempfile.mkdtemp()
@@ -58,6 +60,7 @@ def test_mp_dataset(self):
class TestDistCreateDataset(DistTestCase):
+
def setUp(self):
self.tempdir = tempfile.mkdtemp()
diff --git a/tests/test_phl_cpu.py b/tests/test_phl_cpu.py
index 98a5018d8e..6f872a4776 100644
--- a/tests/test_phl_cpu.py
+++ b/tests/test_phl_cpu.py
@@ -242,6 +242,7 @@
@skip_if_no_cpp_extension
class PHLFilterTestCaseCpu(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_cpu(self, test_case_description, sigmas, input, features, expected):
# Create input tensors
diff --git a/tests/test_phl_cuda.py b/tests/test_phl_cuda.py
index 0ddfd5eaae..b410ea8722 100644
--- a/tests/test_phl_cuda.py
+++ b/tests/test_phl_cuda.py
@@ -150,6 +150,7 @@
@skip_if_no_cuda
@skip_if_no_cpp_extension
class PHLFilterTestCaseCuda(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_cuda(self, test_case_description, sigmas, input, features, expected):
# Create input tensors
diff --git a/tests/test_pil_reader.py b/tests/test_pil_reader.py
index dfa5eb725d..078812513d 100644
--- a/tests/test_pil_reader.py
+++ b/tests/test_pil_reader.py
@@ -37,6 +37,7 @@
class TestPNGReader(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape, reverse=True):
test_image = np.random.randint(0, 256, size=data_shape)
diff --git a/tests/test_plot_2d_or_3d_image.py b/tests/test_plot_2d_or_3d_image.py
index 180a6c3443..16241853b3 100644
--- a/tests/test_plot_2d_or_3d_image.py
+++ b/tests/test_plot_2d_or_3d_image.py
@@ -40,6 +40,7 @@
@unittest.skipUnless(has_tb, "Requires SummaryWriter installation")
@SkipIfBeforePyTorchVersion((1, 13)) # issue 6683
class TestPlot2dOr3dImage(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_tb_image(self, shape):
with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_png_rw.py b/tests/test_png_rw.py
index 0b6e8184ea..058cd616cb 100644
--- a/tests/test_png_rw.py
+++ b/tests/test_png_rw.py
@@ -22,6 +22,7 @@
class TestPngWrite(unittest.TestCase):
+
def test_write_gray(self):
with tempfile.TemporaryDirectory() as out_dir:
image_name = os.path.join(out_dir, "test.png")
diff --git a/tests/test_polyval.py b/tests/test_polyval.py
index 113c862cb3..f0215678af 100644
--- a/tests/test_polyval.py
+++ b/tests/test_polyval.py
@@ -31,6 +31,7 @@
class TestPolyval(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_floats(self, coef, x, expected):
result = polyval(coef, x)
diff --git a/tests/test_prepare_batch_default.py b/tests/test_prepare_batch_default.py
index e440f5cfe3..d5a5fbf57e 100644
--- a/tests/test_prepare_batch_default.py
+++ b/tests/test_prepare_batch_default.py
@@ -20,11 +20,13 @@
class TestNet(torch.nn.Module):
+
def forward(self, x: torch.Tensor):
return x
class TestPrepareBatchDefault(unittest.TestCase):
+
def test_dict_content(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = [
diff --git a/tests/test_prepare_batch_default_dist.py b/tests/test_prepare_batch_default_dist.py
index d015cf4b2f..0c53a74834 100644
--- a/tests/test_prepare_batch_default_dist.py
+++ b/tests/test_prepare_batch_default_dist.py
@@ -43,11 +43,13 @@
class TestNet(torch.nn.Module):
+
def forward(self, x: torch.Tensor):
return x
class DistributedPrepareBatchDefault(DistTestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
@DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
def test_compute(self, dataloaders):
diff --git a/tests/test_prepare_batch_extra_input.py b/tests/test_prepare_batch_extra_input.py
index 1769a19e4a..f20c6e7352 100644
--- a/tests/test_prepare_batch_extra_input.py
+++ b/tests/test_prepare_batch_extra_input.py
@@ -36,11 +36,13 @@
class TestNet(torch.nn.Module):
+
def forward(self, x: torch.Tensor, t1=None, t2=None, t3=None):
return {"x": x, "t1": t1, "t2": t2, "t3": t3}
class TestPrepareBatchExtraInput(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
def test_content(self, input_args, expected_value):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diff --git a/tests/test_prepare_batch_hovernet.py b/tests/test_prepare_batch_hovernet.py
index 5a7080a225..773fcb53bf 100644
--- a/tests/test_prepare_batch_hovernet.py
+++ b/tests/test_prepare_batch_hovernet.py
@@ -28,11 +28,13 @@
class TestNet(torch.nn.Module):
+
def forward(self, x: torch.Tensor):
return {HoVerNetBranch.NP: torch.tensor([1, 2]), HoVerNetBranch.NC: torch.tensor([4, 4]), HoVerNetBranch.HV: 16}
class TestPrepareBatchHoVerNet(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_0])
def test_content(self, input_args, expected_value):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diff --git a/tests/test_preset_filters.py b/tests/test_preset_filters.py
index 9bca24cef3..46ed461f7d 100644
--- a/tests/test_preset_filters.py
+++ b/tests/test_preset_filters.py
@@ -63,6 +63,7 @@
class _TestFilter:
+
def test_init(self, spatial_dims, size, expected):
test_filter = self.filter_class(spatial_dims=spatial_dims, size=size)
torch.testing.assert_allclose(expected, test_filter.filter)
@@ -75,6 +76,7 @@ def test_forward(self):
class TestApplyFilter(unittest.TestCase):
+
def test_init_and_forward_2d(self):
filter_2d = torch.ones(3, 3)
image_2d = torch.ones(1, 3, 3)
@@ -91,6 +93,7 @@ def test_init_and_forward_3d(self):
class MeanFilterTestCase(_TestFilter, unittest.TestCase):
+
def setUp(self) -> None:
self.filter_class = MeanFilter
@@ -100,6 +103,7 @@ def test_init(self, spatial_dims, size, expected):
class LaplaceFilterTestCase(_TestFilter, unittest.TestCase):
+
def setUp(self) -> None:
self.filter_class = LaplaceFilter
@@ -109,6 +113,7 @@ def test_init(self, spatial_dims, size, expected):
class EllipticalTestCase(_TestFilter, unittest.TestCase):
+
def setUp(self) -> None:
self.filter_class = EllipticalFilter
@@ -118,6 +123,7 @@ def test_init(self, spatial_dims, size, expected):
class SharpenTestCase(_TestFilter, unittest.TestCase):
+
def setUp(self) -> None:
self.filter_class = SharpenFilter
diff --git a/tests/test_print_info.py b/tests/test_print_info.py
index bb748c3f7b..aa152e183c 100644
--- a/tests/test_print_info.py
+++ b/tests/test_print_info.py
@@ -17,6 +17,7 @@
class TestPrintInfo(unittest.TestCase):
+
def test_print_info(self):
print_debug_info()
diff --git a/tests/test_print_transform_backends.py b/tests/test_print_transform_backends.py
index 4cd93c3fb2..2072aa4cfa 100644
--- a/tests/test_print_transform_backends.py
+++ b/tests/test_print_transform_backends.py
@@ -17,6 +17,7 @@
class TestPrintTransformBackends(unittest.TestCase):
+
def test_get_number_of_conversions(self):
tr_t_or_np, *_ = get_transform_backends()
self.assertGreater(len(tr_t_or_np), 0)
diff --git a/tests/test_probnms.py b/tests/test_probnms.py
index 8da5396fac..2b52583ad4 100644
--- a/tests/test_probnms.py
+++ b/tests/test_probnms.py
@@ -61,6 +61,7 @@
class TestProbNMS(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_output(self, class_args, probs_map, expected):
nms = ProbNMS(**class_args)
diff --git a/tests/test_probnmsd.py b/tests/test_probnmsd.py
index 1f0288811e..aeb32bdb79 100644
--- a/tests/test_probnmsd.py
+++ b/tests/test_probnmsd.py
@@ -68,6 +68,7 @@
class TestProbNMS(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_output(self, class_args, probs_map, expected):
nms = ProbNMSD(keys="prob_map", **class_args)
diff --git a/tests/test_profiling.py b/tests/test_profiling.py
index 2b93fae196..6bee7ba262 100644
--- a/tests/test_profiling.py
+++ b/tests/test_profiling.py
@@ -29,6 +29,7 @@
class TestWorkflowProfiler(unittest.TestCase):
+
def setUp(self):
super().setUp()
diff --git a/tests/test_pytorch_version_after.py b/tests/test_pytorch_version_after.py
index 4c8c032c80..147707d2c0 100644
--- a/tests/test_pytorch_version_after.py
+++ b/tests/test_pytorch_version_after.py
@@ -38,6 +38,7 @@
class TestPytorchVersionCompare(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_compare(self, a, b, p, current, expected=True):
"""Test pytorch_after with a and b"""
diff --git a/tests/test_query_memory.py b/tests/test_query_memory.py
index 5e57913acb..77c34ede39 100644
--- a/tests/test_query_memory.py
+++ b/tests/test_query_memory.py
@@ -17,6 +17,7 @@
class TestQueryMemory(unittest.TestCase):
+
def test_output_str(self):
self.assertTrue(isinstance(query_memory(2), str))
all_device = query_memory(-1)
diff --git a/tests/test_quicknat.py b/tests/test_quicknat.py
index b4b89b7d62..f6786405d2 100644
--- a/tests/test_quicknat.py
+++ b/tests/test_quicknat.py
@@ -38,6 +38,7 @@
@unittest.skipUnless(has_se, "squeeze_and_excitation not installed")
class TestQuicknat(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_shape, expected_shape):
device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/tests/test_rand_adjust_contrast.py b/tests/test_rand_adjust_contrast.py
index bfeedc2fcf..72d0df141e 100644
--- a/tests/test_rand_adjust_contrast.py
+++ b/tests/test_rand_adjust_contrast.py
@@ -25,6 +25,7 @@
class TestRandAdjustContrast(NumpyImageTestCase2D):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_correct_results(self, gamma):
adjuster = RandAdjustContrast(prob=1.0, gamma=gamma)
diff --git a/tests/test_rand_adjust_contrastd.py b/tests/test_rand_adjust_contrastd.py
index 4037266da4..bbd5c22009 100644
--- a/tests/test_rand_adjust_contrastd.py
+++ b/tests/test_rand_adjust_contrastd.py
@@ -25,6 +25,7 @@
class TestRandAdjustContrastd(NumpyImageTestCase2D):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_correct_results(self, gamma):
adjuster = RandAdjustContrastd("img", prob=1.0, gamma=gamma)
diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py
index 915b14bf51..f37f7827bb 100644
--- a/tests/test_rand_affine.py
+++ b/tests/test_rand_affine.py
@@ -140,6 +140,7 @@
class TestRandAffine(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_rand_affine(self, input_param, input_data, expected_val):
g = RandAffine(**input_param)
diff --git a/tests/test_rand_affine_grid.py b/tests/test_rand_affine_grid.py
index 113987a85c..91558ebd03 100644
--- a/tests/test_rand_affine_grid.py
+++ b/tests/test_rand_affine_grid.py
@@ -198,6 +198,7 @@
class TestRandAffineGrid(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_rand_affine_grid(self, input_param, input_data, expected_val):
g = RandAffineGrid(**input_param)
diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py
index a607029c1a..20c50954e2 100644
--- a/tests/test_rand_affined.py
+++ b/tests/test_rand_affined.py
@@ -216,6 +216,7 @@
class TestRandAffined(unittest.TestCase):
+
@parameterized.expand(x + [y] for x, y in itertools.product(TESTS, (False, True)))
def test_rand_affined(self, input_param, input_data, expected_val, track_meta):
set_track_meta(track_meta)
diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py
index 81e42372db..9c465a0bcb 100644
--- a/tests/test_rand_axis_flip.py
+++ b/tests/test_rand_axis_flip.py
@@ -23,6 +23,7 @@
class TestRandAxisFlip(NumpyImageTestCase2D):
+
def test_correct_results(self):
for p in TEST_NDARRAYS_ALL:
flip = RandAxisFlip(prob=1.0)
diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py
index 75357b23e1..d3abef1be4 100644
--- a/tests/test_rand_axis_flipd.py
+++ b/tests/test_rand_axis_flipd.py
@@ -23,6 +23,7 @@
class TestRandAxisFlip(NumpyImageTestCase3D):
+
def test_correct_results(self):
for p in TEST_NDARRAYS_ALL:
flip = RandAxisFlipd(keys="img", prob=1.0)
diff --git a/tests/test_rand_bias_field.py b/tests/test_rand_bias_field.py
index 16f615146f..333a9ecba5 100644
--- a/tests/test_rand_bias_field.py
+++ b/tests/test_rand_bias_field.py
@@ -30,6 +30,7 @@
class TestRandBiasField(unittest.TestCase):
+
@parameterized.expand([TEST_CASES_2D, TEST_CASES_3D])
def test_output_shape(self, class_args, img_shape):
for p in TEST_NDARRAYS:
diff --git a/tests/test_rand_bias_fieldd.py b/tests/test_rand_bias_fieldd.py
index 2b8a60289d..1f174fa397 100644
--- a/tests/test_rand_bias_fieldd.py
+++ b/tests/test_rand_bias_fieldd.py
@@ -28,6 +28,7 @@
class TestRandBiasFieldd(unittest.TestCase):
+
@parameterized.expand([TEST_CASES_2D, TEST_CASES_3D])
def test_output_shape(self, class_args, img_shape):
key = "img"
diff --git a/tests/test_rand_coarse_dropout.py b/tests/test_rand_coarse_dropout.py
index 8c3876f10b..ac857f9184 100644
--- a/tests/test_rand_coarse_dropout.py
+++ b/tests/test_rand_coarse_dropout.py
@@ -63,6 +63,7 @@
class TestRandCoarseDropout(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]
)
diff --git a/tests/test_rand_coarse_dropoutd.py b/tests/test_rand_coarse_dropoutd.py
index 7b16f992b7..bfc6a2f27f 100644
--- a/tests/test_rand_coarse_dropoutd.py
+++ b/tests/test_rand_coarse_dropoutd.py
@@ -63,6 +63,7 @@
class TestRandCoarseDropoutd(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
def test_value(self, input_param, input_data):
dropout = RandCoarseDropoutd(**input_param)
diff --git a/tests/test_rand_coarse_shuffle.py b/tests/test_rand_coarse_shuffle.py
index adfb722b42..39e62c22a8 100644
--- a/tests/test_rand_coarse_shuffle.py
+++ b/tests/test_rand_coarse_shuffle.py
@@ -52,6 +52,7 @@
class TestRandCoarseShuffle(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shuffle(self, input_param, input_data, expected_val):
g = RandCoarseShuffle(**input_param)
diff --git a/tests/test_rand_coarse_shuffled.py b/tests/test_rand_coarse_shuffled.py
index 3b5a1434f4..f49066efd9 100644
--- a/tests/test_rand_coarse_shuffled.py
+++ b/tests/test_rand_coarse_shuffled.py
@@ -46,6 +46,7 @@
class TestRandCoarseShuffled(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shuffle(self, input_param, input_data, expected_val):
g = RandCoarseShuffled(**input_param)
diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py
index 88d2631ca5..743b894d75 100644
--- a/tests/test_rand_crop_by_label_classes.py
+++ b/tests/test_rand_crop_by_label_classes.py
@@ -127,6 +127,7 @@
class TestRandCropByLabelClasses(unittest.TestCase):
+
@parameterized.expand(TESTS_INDICES + TESTS_SHAPE)
def test_type_shape(self, input_param, input_data, expected_type, expected_shape):
result = RandCropByLabelClasses(**input_param)(**input_data)
diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py
index 748f26f1ff..8908c456ee 100644
--- a/tests/test_rand_crop_by_label_classesd.py
+++ b/tests/test_rand_crop_by_label_classesd.py
@@ -120,6 +120,7 @@
class TestRandCropByLabelClassesd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_type_shape(self, input_param, input_data, expected_type, expected_shape):
result = RandCropByLabelClassesd(**input_param)(input_data)
diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py
index 98af6b0b5e..66e7a5e849 100644
--- a/tests/test_rand_crop_by_pos_neg_label.py
+++ b/tests/test_rand_crop_by_pos_neg_label.py
@@ -96,6 +96,7 @@
class TestRandCropByPosNegLabel(unittest.TestCase):
+
@staticmethod
def convert_data_type(im_type, d, keys=("img", "image", "label")):
out = deepcopy(d)
diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py
index 1b57548d12..11381e226d 100644
--- a/tests/test_rand_crop_by_pos_neg_labeld.py
+++ b/tests/test_rand_crop_by_pos_neg_labeld.py
@@ -107,6 +107,7 @@
class TestRandCropByPosNegLabeld(unittest.TestCase):
+
@staticmethod
def convert_data_type(im_type, d, keys=("img", "image", "label")):
out = deepcopy(d)
diff --git a/tests/test_rand_cucim_dict_transform.py b/tests/test_rand_cucim_dict_transform.py
index 33e0667723..3f473897dd 100644
--- a/tests/test_rand_cucim_dict_transform.py
+++ b/tests/test_rand_cucim_dict_transform.py
@@ -78,6 +78,7 @@
@unittest.skipUnless(HAS_CUPY, "CuPy is required.")
@unittest.skipUnless(has_cut, "cuCIM transforms are required.")
class TestRandCuCIMDict(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_COLOR_JITTER_1,
diff --git a/tests/test_rand_cucim_transform.py b/tests/test_rand_cucim_transform.py
index 37d8e29f1d..ce731a05ae 100644
--- a/tests/test_rand_cucim_transform.py
+++ b/tests/test_rand_cucim_transform.py
@@ -78,6 +78,7 @@
@unittest.skipUnless(HAS_CUPY, "CuPy is required.")
@unittest.skipUnless(has_cut, "cuCIM transforms are required.")
class TestRandCuCIM(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_COLOR_JITTER_1,
diff --git a/tests/test_rand_deform_grid.py b/tests/test_rand_deform_grid.py
index 58b64ae596..88fc1333ec 100644
--- a/tests/test_rand_deform_grid.py
+++ b/tests/test_rand_deform_grid.py
@@ -126,6 +126,7 @@
class TestRandDeformGrid(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_rand_deform_grid(self, input_param, input_data, expected_val):
g = RandDeformGrid(**input_param)
diff --git a/tests/test_rand_elastic_2d.py b/tests/test_rand_elastic_2d.py
index c59052854f..1f3d389a93 100644
--- a/tests/test_rand_elastic_2d.py
+++ b/tests/test_rand_elastic_2d.py
@@ -110,6 +110,7 @@
class TestRand2DElastic(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_rand_2d_elastic(self, input_param, input_data, expected_val):
g = Rand2DElastic(**input_param)
diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py
index 0ff3ef6129..5bfa8a6e83 100644
--- a/tests/test_rand_elastic_3d.py
+++ b/tests/test_rand_elastic_3d.py
@@ -86,6 +86,7 @@
class TestRand3DElastic(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_rand_3d_elastic(self, input_param, input_data, expected_val):
g = Rand3DElastic(**input_param)
diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py
index d0fbd5aa88..10aa116192 100644
--- a/tests/test_rand_elasticd_2d.py
+++ b/tests/test_rand_elasticd_2d.py
@@ -160,6 +160,7 @@
class TestRand2DElasticd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_rand_2d_elasticd(self, input_param, input_data, expected_val):
g = Rand2DElasticd(**input_param)
diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py
index e058293584..3838f43f29 100644
--- a/tests/test_rand_elasticd_3d.py
+++ b/tests/test_rand_elasticd_3d.py
@@ -139,6 +139,7 @@
class TestRand3DElasticd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_rand_3d_elasticd(self, input_param, input_data, expected_val):
g = Rand3DElasticd(**input_param)
diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py
index c3b0bfdede..faeae94cab 100644
--- a/tests/test_rand_flip.py
+++ b/tests/test_rand_flip.py
@@ -28,6 +28,7 @@
class TestRandFlip(NumpyImageTestCase2D):
+
@parameterized.expand(INVALID_CASES)
def test_invalid_inputs(self, _, spatial_axis, raises):
with self.assertRaises(raises):
diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py
index be5394c172..a34aa58ed2 100644
--- a/tests/test_rand_flipd.py
+++ b/tests/test_rand_flipd.py
@@ -26,6 +26,7 @@
class TestRandFlipd(NumpyImageTestCase2D):
+
@parameterized.expand(VALID_CASES)
def test_correct_results(self, _, spatial_axis):
for p in TEST_NDARRAYS_ALL:
diff --git a/tests/test_rand_gaussian_noise.py b/tests/test_rand_gaussian_noise.py
index 7d4d04ff3f..a56e54fe31 100644
--- a/tests/test_rand_gaussian_noise.py
+++ b/tests/test_rand_gaussian_noise.py
@@ -27,6 +27,7 @@
class TestRandGaussianNoise(NumpyImageTestCase2D):
+
@parameterized.expand(TESTS)
def test_correct_results(self, _, im_type, mean, std):
seed = 0
diff --git a/tests/test_rand_gaussian_noised.py b/tests/test_rand_gaussian_noised.py
index 24fc19f226..bcbed98b5a 100644
--- a/tests/test_rand_gaussian_noised.py
+++ b/tests/test_rand_gaussian_noised.py
@@ -29,6 +29,7 @@
class TestRandGaussianNoised(NumpyImageTestCase2D):
+
@parameterized.expand(TESTS)
def test_correct_results(self, _, im_type, keys, mean, std):
gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64)
diff --git a/tests/test_rand_gaussian_sharpen.py b/tests/test_rand_gaussian_sharpen.py
index 8dff69cd4c..ee8604c14b 100644
--- a/tests/test_rand_gaussian_sharpen.py
+++ b/tests/test_rand_gaussian_sharpen.py
@@ -128,6 +128,7 @@
class TestRandGaussianSharpen(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
converter = RandGaussianSharpen(**arguments)
diff --git a/tests/test_rand_gaussian_sharpend.py b/tests/test_rand_gaussian_sharpend.py
index 4c32880053..b9bae529db 100644
--- a/tests/test_rand_gaussian_sharpend.py
+++ b/tests/test_rand_gaussian_sharpend.py
@@ -131,6 +131,7 @@
class TestRandGaussianSharpend(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
converter = RandGaussianSharpend(**arguments)
diff --git a/tests/test_rand_gaussian_smooth.py b/tests/test_rand_gaussian_smooth.py
index 9fb91a38a1..8bb36ca0fa 100644
--- a/tests/test_rand_gaussian_smooth.py
+++ b/tests/test_rand_gaussian_smooth.py
@@ -86,6 +86,7 @@
class TestRandGaussianSmooth(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
converter = RandGaussianSmooth(**arguments)
diff --git a/tests/test_rand_gaussian_smoothd.py b/tests/test_rand_gaussian_smoothd.py
index d312494e46..a93b355184 100644
--- a/tests/test_rand_gaussian_smoothd.py
+++ b/tests/test_rand_gaussian_smoothd.py
@@ -86,6 +86,7 @@
class TestRandGaussianSmoothd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
converter = RandGaussianSmoothd(**arguments)
diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py
index a0d18ae7f3..4befeffbe2 100644
--- a/tests/test_rand_gibbs_noise.py
+++ b/tests/test_rand_gibbs_noise.py
@@ -32,6 +32,7 @@
class TestRandGibbsNoise(unittest.TestCase):
+
def setUp(self):
set_determinism(0)
super().setUp()
diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py
index 4120f967e2..6580189af6 100644
--- a/tests/test_rand_gibbs_noised.py
+++ b/tests/test_rand_gibbs_noised.py
@@ -34,6 +34,7 @@
class TestRandGibbsNoised(unittest.TestCase):
+
def setUp(self):
set_determinism(0)
super().setUp()
diff --git a/tests/test_rand_grid_distortion.py b/tests/test_rand_grid_distortion.py
index 8131a2382a..e07c311b25 100644
--- a/tests/test_rand_grid_distortion.py
+++ b/tests/test_rand_grid_distortion.py
@@ -84,6 +84,7 @@
class TestRandGridDistortion(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_rand_grid_distortion(self, input_param, seed, input_data, expected_val):
g = RandGridDistortion(**input_param)
diff --git a/tests/test_rand_grid_distortiond.py b/tests/test_rand_grid_distortiond.py
index 9f8ed3b9e6..f28e0ae86e 100644
--- a/tests/test_rand_grid_distortiond.py
+++ b/tests/test_rand_grid_distortiond.py
@@ -77,6 +77,7 @@
class TestRandGridDistortiond(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_rand_grid_distortiond(self, input_param, seed, input_data, expected_val_img, expected_val_mask):
g = RandGridDistortiond(**input_param)
diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py
index 494330584a..26863f01b2 100644
--- a/tests/test_rand_grid_patch.py
+++ b/tests/test_rand_grid_patch.py
@@ -105,6 +105,7 @@
class TestRandGridPatch(unittest.TestCase):
+
def setUp(self):
set_determinism(seed=1234)
diff --git a/tests/test_rand_grid_patchd.py b/tests/test_rand_grid_patchd.py
index 23ca4a7881..031e834512 100644
--- a/tests/test_rand_grid_patchd.py
+++ b/tests/test_rand_grid_patchd.py
@@ -85,6 +85,7 @@
class TestRandGridPatchd(unittest.TestCase):
+
def setUp(self):
set_determinism(seed=1234)
diff --git a/tests/test_rand_histogram_shift.py b/tests/test_rand_histogram_shift.py
index 318dad9dfa..785e24e53b 100644
--- a/tests/test_rand_histogram_shift.py
+++ b/tests/test_rand_histogram_shift.py
@@ -56,6 +56,7 @@
class TestRandHistogramShift(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_rand_histogram_shift(self, input_param, input_data, expected_val):
g = RandHistogramShift(**input_param)
diff --git a/tests/test_rand_histogram_shiftd.py b/tests/test_rand_histogram_shiftd.py
index 45e81ab012..fced270e90 100644
--- a/tests/test_rand_histogram_shiftd.py
+++ b/tests/test_rand_histogram_shiftd.py
@@ -61,6 +61,7 @@
class TestRandHistogramShiftD(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_rand_histogram_shiftd(self, input_param, input_data, expected_val):
g = RandHistogramShiftd(**input_param)
diff --git a/tests/test_rand_k_space_spike_noise.py b/tests/test_rand_k_space_spike_noise.py
index 4e7d59329b..7a9dd4288d 100644
--- a/tests/test_rand_k_space_spike_noise.py
+++ b/tests/test_rand_k_space_spike_noise.py
@@ -29,6 +29,7 @@
class TestRandKSpaceSpikeNoise(unittest.TestCase):
+
def setUp(self):
set_determinism(0)
super().setUp()
diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py
index 3e1c11b2d9..86d4256637 100644
--- a/tests/test_rand_k_space_spike_noised.py
+++ b/tests/test_rand_k_space_spike_noised.py
@@ -30,6 +30,7 @@
class TestKSpaceSpikeNoised(unittest.TestCase):
+
def setUp(self):
set_determinism(0)
super().setUp()
diff --git a/tests/test_rand_lambda.py b/tests/test_rand_lambda.py
index 1f14499bc0..98a324aec5 100644
--- a/tests/test_rand_lambda.py
+++ b/tests/test_rand_lambda.py
@@ -37,6 +37,7 @@ def __call__(self, data):
class TestRandLambda(unittest.TestCase):
+
def check(self, tr: RandLambda, img, img_orig_type, out, expected=None):
# input shouldn't change
self.assertIsInstance(img, img_orig_type)
diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py
index 6b60a3fe70..5247d79843 100644
--- a/tests/test_rand_lambdad.py
+++ b/tests/test_rand_lambdad.py
@@ -37,6 +37,7 @@ def __call__(self, data):
class TestRandLambdad(unittest.TestCase):
+
def check(self, tr: RandLambdad, input: dict, out: dict, expected: dict):
if isinstance(input["img"], MetaTensor):
self.assertEqual(len(input["img"].applied_operations), 0)
diff --git a/tests/test_rand_rician_noise.py b/tests/test_rand_rician_noise.py
index fe7135835e..8dd1c48e29 100644
--- a/tests/test_rand_rician_noise.py
+++ b/tests/test_rand_rician_noise.py
@@ -27,6 +27,7 @@
class TestRandRicianNoise(NumpyImageTestCase2D):
+
@parameterized.expand(TESTS)
def test_correct_results(self, _, in_type, mean, std):
seed = 0
diff --git a/tests/test_rand_rician_noised.py b/tests/test_rand_rician_noised.py
index ae0acab4eb..a190ba866d 100644
--- a/tests/test_rand_rician_noised.py
+++ b/tests/test_rand_rician_noised.py
@@ -29,6 +29,7 @@
class TestRandRicianNoisedNumpy(NumpyImageTestCase2D):
+
@parameterized.expand(TESTS)
def test_correct_results(self, _, in_type, keys, mean, std):
rician_fn = RandRicianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64)
diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py
index ca3eda3b12..c54229dcfe 100644
--- a/tests/test_rand_rotate.py
+++ b/tests/test_rand_rotate.py
@@ -73,6 +73,7 @@
class TestRandRotate2D(NumpyImageTestCase2D):
+
@parameterized.expand(TEST_CASES_2D)
def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners):
init_param = {
@@ -112,6 +113,7 @@ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode,
@unittest.skipIf(USE_COMPILED, "unit tests not for compiled version.")
class TestRandRotate3D(NumpyImageTestCase3D):
+
@parameterized.expand(TEST_CASES_3D)
def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected):
init_param = {
@@ -146,6 +148,7 @@ def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode,
class TestRandRotateDtype(NumpyImageTestCase2D):
+
@parameterized.expand(TEST_CASES_2D)
def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners):
rotate_fn = RandRotate(
diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py
index 88f88bf422..be2e658b78 100644
--- a/tests/test_rand_rotate90.py
+++ b/tests/test_rand_rotate90.py
@@ -23,6 +23,7 @@
class TestRandRotate90(NumpyImageTestCase2D):
+
def test_default(self):
rotate = RandRotate90()
for p in TEST_NDARRAYS_ALL:
diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py
index 23e9025c08..02836b5dd8 100644
--- a/tests/test_rand_rotate90d.py
+++ b/tests/test_rand_rotate90d.py
@@ -23,6 +23,7 @@
class TestRandRotate90d(NumpyImageTestCase2D):
+
def test_default(self):
key = "test"
rotate = RandRotate90d(keys=key)
diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py
index a5a377b02f..71d0f67b63 100644
--- a/tests/test_rand_rotated.py
+++ b/tests/test_rand_rotated.py
@@ -109,6 +109,7 @@
class TestRandRotated2D(NumpyImageTestCase2D):
+
@parameterized.expand(TEST_CASES_2D)
def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners):
init_param = {
@@ -153,6 +154,7 @@ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode,
@unittest.skipIf(USE_COMPILED, "unit tests not for compiled version.")
class TestRandRotated3D(NumpyImageTestCase3D):
+
@parameterized.expand(TEST_CASES_3D)
def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected):
init_param = {
diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py
index a857c0cefb..7e999c00b3 100644
--- a/tests/test_rand_scale_intensity.py
+++ b/tests/test_rand_scale_intensity.py
@@ -21,6 +21,7 @@
class TestRandScaleIntensity(NumpyImageTestCase2D):
+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_value(self, p):
scaler = RandScaleIntensity(factors=0.5, prob=1.0)
diff --git a/tests/test_rand_scale_intensity_fixed_mean.py b/tests/test_rand_scale_intensity_fixed_mean.py
index f43adab32f..9324c711fa 100644
--- a/tests/test_rand_scale_intensity_fixed_mean.py
+++ b/tests/test_rand_scale_intensity_fixed_mean.py
@@ -21,6 +21,7 @@
class TestRandScaleIntensity(NumpyImageTestCase2D):
+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_value(self, p):
scaler = RandScaleIntensityFixedMean(prob=1.0, factors=0.5)
diff --git a/tests/test_rand_scale_intensity_fixed_meand.py b/tests/test_rand_scale_intensity_fixed_meand.py
index c85c764a55..8c127ac130 100644
--- a/tests/test_rand_scale_intensity_fixed_meand.py
+++ b/tests/test_rand_scale_intensity_fixed_meand.py
@@ -20,6 +20,7 @@
class TestRandScaleIntensityFixedMeand(NumpyImageTestCase2D):
+
def test_value(self):
key = "img"
for p in TEST_NDARRAYS:
diff --git a/tests/test_rand_scale_intensityd.py b/tests/test_rand_scale_intensityd.py
index 8d928ac157..32c96f0313 100644
--- a/tests/test_rand_scale_intensityd.py
+++ b/tests/test_rand_scale_intensityd.py
@@ -20,6 +20,7 @@
class TestRandScaleIntensityd(NumpyImageTestCase2D):
+
def test_value(self):
key = "img"
for p in TEST_NDARRAYS:
diff --git a/tests/test_rand_shift_intensity.py b/tests/test_rand_shift_intensity.py
index 01ac55f7b8..907773ccf5 100644
--- a/tests/test_rand_shift_intensity.py
+++ b/tests/test_rand_shift_intensity.py
@@ -21,6 +21,7 @@
class TestRandShiftIntensity(NumpyImageTestCase2D):
+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_value(self, p):
shifter = RandShiftIntensity(offsets=1.0, prob=1.0)
diff --git a/tests/test_rand_shift_intensityd.py b/tests/test_rand_shift_intensityd.py
index 7522676eb0..51675e324c 100644
--- a/tests/test_rand_shift_intensityd.py
+++ b/tests/test_rand_shift_intensityd.py
@@ -21,6 +21,7 @@
class TestRandShiftIntensityd(NumpyImageTestCase2D):
+
def test_value(self):
key = "img"
for p in TEST_NDARRAYS:
diff --git a/tests/test_rand_simulate_low_resolution.py b/tests/test_rand_simulate_low_resolution.py
index 7d05faad36..6aa586fb0b 100644
--- a/tests/test_rand_simulate_low_resolution.py
+++ b/tests/test_rand_simulate_low_resolution.py
@@ -71,6 +71,7 @@
class TestRandGaussianSmooth(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
randsimlowres = RandSimulateLowResolution(**arguments)
diff --git a/tests/test_rand_simulate_low_resolutiond.py b/tests/test_rand_simulate_low_resolutiond.py
index f058ec3b2b..5ec84eba1d 100644
--- a/tests/test_rand_simulate_low_resolutiond.py
+++ b/tests/test_rand_simulate_low_resolutiond.py
@@ -60,6 +60,7 @@
class TestRandGaussianSmoothd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
converter = RandSimulateLowResolutiond(**arguments)
diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py
index b37dacd643..cb53e94b7d 100644
--- a/tests/test_rand_spatial_crop_samplesd.py
+++ b/tests/test_rand_spatial_crop_samplesd.py
@@ -90,6 +90,7 @@
class TestRandSpatialCropSamplesd(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, *TEST_CASE_2])
def test_shape(self, input_param, input_data, expected_shape, expected_last):
xform = RandSpatialCropSamplesd(**input_param)
diff --git a/tests/test_rand_std_shift_intensity.py b/tests/test_rand_std_shift_intensity.py
index 535fb7cb20..0ac5e9482e 100644
--- a/tests/test_rand_std_shift_intensity.py
+++ b/tests/test_rand_std_shift_intensity.py
@@ -22,6 +22,7 @@
class TestRandStdShiftIntensity(NumpyImageTestCase2D):
+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_value(self, p):
np.random.seed(0)
diff --git a/tests/test_rand_std_shift_intensityd.py b/tests/test_rand_std_shift_intensityd.py
index 31209ee754..1fd0c5d2a8 100644
--- a/tests/test_rand_std_shift_intensityd.py
+++ b/tests/test_rand_std_shift_intensityd.py
@@ -20,6 +20,7 @@
class TestRandStdShiftIntensityd(NumpyImageTestCase2D):
+
def test_value(self):
for p in TEST_NDARRAYS:
key = "img"
diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py
index 9d37779613..1524442f61 100644
--- a/tests/test_rand_weighted_cropd.py
+++ b/tests/test_rand_weighted_cropd.py
@@ -148,6 +148,7 @@ def get_data(ndim):
class TestRandWeightedCrop(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_rand_weighted_cropd(self, _, init_params, input_data, expected_shape, expected_centers):
crop = RandWeightedCropd(**init_params)
diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py
index d52b79d8cf..2da04fd652 100644
--- a/tests/test_rand_zoom.py
+++ b/tests/test_rand_zoom.py
@@ -33,6 +33,7 @@
class TestRandZoom(NumpyImageTestCase2D):
+
@parameterized.expand(VALID_CASES)
def test_correct_results(self, min_zoom, max_zoom, mode, keep_size, align_corners=None):
for p in TEST_NDARRAYS_ALL:
diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py
index bb0495c793..bcbf188310 100644
--- a/tests/test_rand_zoomd.py
+++ b/tests/test_rand_zoomd.py
@@ -31,6 +31,7 @@
class TestRandZoomd(NumpyImageTestCase2D):
+
@parameterized.expand(VALID_CASES)
def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_size):
key = "img"
diff --git a/tests/test_randidentity.py b/tests/test_randidentity.py
index 09dc055b4e..3a8936f2d2 100644
--- a/tests/test_randidentity.py
+++ b/tests/test_randidentity.py
@@ -19,11 +19,13 @@
class T(mt.Transform):
+
def __call__(self, x):
return x * 2
class TestIdentity(NumpyImageTestCase2D):
+
def test_identity(self):
for p in TEST_NDARRAYS:
img = p(self.imt)
diff --git a/tests/test_random_order.py b/tests/test_random_order.py
index e5507fafca..b38d2398fb 100644
--- a/tests/test_random_order.py
+++ b/tests/test_random_order.py
@@ -30,6 +30,7 @@
class InvC(Inv):
+
def __init__(self, keys):
super().__init__(keys)
self.fwd_fn = lambda x: x + 1
@@ -37,6 +38,7 @@ def __init__(self, keys):
class InvD(Inv):
+
def __init__(self, keys):
super().__init__(keys)
self.fwd_fn = lambda x: x * 100
@@ -55,6 +57,7 @@ def __init__(self, keys):
class TestRandomOrder(unittest.TestCase):
+
def test_empty_compose(self):
c = RandomOrder()
i = 1
@@ -113,6 +116,7 @@ def test_inverse(self, transform, invertible, use_metatensor):
class TestRandomOrderAPITests(unittest.TestCase):
+
@staticmethod
def data_from_keys(keys):
if keys is None:
diff --git a/tests/test_randomizable.py b/tests/test_randomizable.py
index 96854a6db8..56d5293130 100644
--- a/tests/test_randomizable.py
+++ b/tests/test_randomizable.py
@@ -19,11 +19,13 @@
class RandTest(Randomizable):
+
def randomize(self, data=None):
pass
class TestRandomizable(unittest.TestCase):
+
def test_default(self):
inst = RandTest()
r1 = inst.R.rand()
diff --git a/tests/test_randomizable_transform_type.py b/tests/test_randomizable_transform_type.py
index 3a0995be68..919f9299bf 100644
--- a/tests/test_randomizable_transform_type.py
+++ b/tests/test_randomizable_transform_type.py
@@ -21,11 +21,13 @@ class InheritsInterface(RandomizableTrait):
class InheritsImplementation(RandomizableTransform):
+
def __call__(self, data):
return data
class TestRandomizableTransformType(unittest.TestCase):
+
def test_is_randomizable_transform_type(self):
inst = InheritsInterface()
self.assertIsInstance(inst, RandomizableTrait)
diff --git a/tests/test_randtorchvisiond.py b/tests/test_randtorchvisiond.py
index 82f9adf473..7ad06dfd2a 100644
--- a/tests/test_randtorchvisiond.py
+++ b/tests/test_randtorchvisiond.py
@@ -52,6 +52,7 @@
class TestRandTorchVisiond(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_value(self, input_param, input_data, expected_value):
set_determinism(seed=0)
diff --git a/tests/test_rankfilter_dist.py b/tests/test_rankfilter_dist.py
index 40cd36f31d..fd02e3bdc9 100644
--- a/tests/test_rankfilter_dist.py
+++ b/tests/test_rankfilter_dist.py
@@ -23,6 +23,7 @@
class DistributedRankFilterTest(DistTestCase):
+
def setUp(self):
self.log_dir = tempfile.TemporaryDirectory()
@@ -50,6 +51,7 @@ def tearDown(self) -> None:
class SingleRankFilterTest(unittest.TestCase):
+
def tearDown(self) -> None:
self.log_dir.cleanup()
diff --git a/tests/test_recon_net_utils.py b/tests/test_recon_net_utils.py
index 38adb9617b..1815000777 100644
--- a/tests/test_recon_net_utils.py
+++ b/tests/test_recon_net_utils.py
@@ -49,6 +49,7 @@
class TestReconNetUtils(unittest.TestCase):
+
@parameterized.expand(TEST_RESHAPE)
def test_reshape_channel_complex(self, test_data):
result = reshape_complex_to_channel_dim(test_data)
diff --git a/tests/test_reference_based_normalize_intensity.py b/tests/test_reference_based_normalize_intensity.py
index 8d2715f983..2d946af118 100644
--- a/tests/test_reference_based_normalize_intensity.py
+++ b/tests/test_reference_based_normalize_intensity.py
@@ -52,6 +52,7 @@
class TestDetailedNormalizeIntensityd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_target_mean_std(self, args, data, normalized_data, normalized_target, mean, std):
dtype = data[args["keys"][0]].dtype
diff --git a/tests/test_reference_based_spatial_cropd.py b/tests/test_reference_based_spatial_cropd.py
index d5777482c0..83cd9c4a5d 100644
--- a/tests/test_reference_based_spatial_cropd.py
+++ b/tests/test_reference_based_spatial_cropd.py
@@ -46,6 +46,7 @@
class TestTargetBasedSpatialCropd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, args, data, expected_shape):
cropper = ReferenceBasedSpatialCropd(keys=args["keys"], ref_key=args["ref_key"])
diff --git a/tests/test_reference_resolver.py b/tests/test_reference_resolver.py
index 07d56a16df..1f02bb01a7 100644
--- a/tests/test_reference_resolver.py
+++ b/tests/test_reference_resolver.py
@@ -70,6 +70,7 @@
class TestReferenceResolver(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2] + ([TEST_CASE_3] if has_tv else []))
def test_resolve(self, configs, expected_id, output_type):
locator = ComponentLocator()
diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py
index 6cd973c32e..e8f82eb0c2 100644
--- a/tests/test_reg_loss_integration.py
+++ b/tests/test_reg_loss_integration.py
@@ -32,6 +32,7 @@
class TestRegLossIntegration(unittest.TestCase):
+
def setUp(self):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
@@ -61,6 +62,7 @@ def test_convergence(self, loss_type, loss_args, forward_args, pred_channels=1):
# define a one layer model
class OnelayerNet(nn.Module):
+
def __init__(self):
super().__init__()
self.layer = nn.Sequential(
diff --git a/tests/test_regunet.py b/tests/test_regunet.py
index 04ff60ef30..3100d7660c 100644
--- a/tests/test_regunet.py
+++ b/tests/test_regunet.py
@@ -63,6 +63,7 @@
class TestREGUNET(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_REGUNET_2D + TEST_CASE_REGUNET_3D)
def test_shape(self, input_param, input_shape, expected_shape):
net = RegUNet(**input_param).to(device)
diff --git a/tests/test_regunet_block.py b/tests/test_regunet_block.py
index eebe9d8694..fa07671d03 100644
--- a/tests/test_regunet_block.py
+++ b/tests/test_regunet_block.py
@@ -65,6 +65,7 @@
class TestRegistrationResidualConvBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_RESIDUAL)
def test_shape(self, input_param, input_shape, expected_shape):
net = RegistrationResidualConvBlock(**input_param)
@@ -74,6 +75,7 @@ def test_shape(self, input_param, input_shape, expected_shape):
class TestRegistrationDownSampleBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_DOWN_SAMPLE)
def test_shape(self, input_param, input_shape, expected_shape):
net = RegistrationDownSampleBlock(**input_param)
@@ -88,6 +90,7 @@ def test_ill_shape(self):
class TestRegistrationExtractionBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_EXTRACTION)
def test_shape(self, input_param, input_shapes, image_size, expected_shape):
net = RegistrationExtractionBlock(**input_param)
diff --git a/tests/test_remove_repeated_channel.py b/tests/test_remove_repeated_channel.py
index 90b1b79b03..7da00ee75d 100644
--- a/tests/test_remove_repeated_channel.py
+++ b/tests/test_remove_repeated_channel.py
@@ -24,6 +24,7 @@
class TestRemoveRepeatedChannel(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_shape):
result = RemoveRepeatedChannel(**input_param)(input_data)
diff --git a/tests/test_remove_repeated_channeld.py b/tests/test_remove_repeated_channeld.py
index 6d36d32f6f..08ec7fb44c 100644
--- a/tests/test_remove_repeated_channeld.py
+++ b/tests/test_remove_repeated_channeld.py
@@ -34,6 +34,7 @@
class TestRemoveRepeatedChanneld(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, input_param, input_data, expected_shape):
result = RemoveRepeatedChanneld(**input_param)(input_data)
diff --git a/tests/test_remove_small_objects.py b/tests/test_remove_small_objects.py
index 200f4ed9b2..633a6d9a99 100644
--- a/tests/test_remove_small_objects.py
+++ b/tests/test_remove_small_objects.py
@@ -55,6 +55,7 @@
@SkipIfNoModule("skimage.morphology")
class TestRemoveSmallObjects(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_remove_small_objects(self, dtype, im_type, lbl, expected, params=None):
params = params or {}
diff --git a/tests/test_repeat_channel.py b/tests/test_repeat_channel.py
index 0ae5743836..82d1d92bd2 100644
--- a/tests/test_repeat_channel.py
+++ b/tests/test_repeat_channel.py
@@ -24,6 +24,7 @@
class TestRepeatChannel(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, input_param, input_data, expected_shape):
result = RepeatChannel(**input_param)(input_data)
diff --git a/tests/test_repeat_channeld.py b/tests/test_repeat_channeld.py
index 9f7872135d..2be13a08d1 100644
--- a/tests/test_repeat_channeld.py
+++ b/tests/test_repeat_channeld.py
@@ -31,6 +31,7 @@
class TestRepeatChanneld(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, input_param, input_data, expected_shape):
result = RepeatChanneld(**input_param)(input_data)
diff --git a/tests/test_replace_module.py b/tests/test_replace_module.py
index cac3fd39e5..f3964ac65d 100644
--- a/tests/test_replace_module.py
+++ b/tests/test_replace_module.py
@@ -32,6 +32,7 @@
class TestReplaceModule(unittest.TestCase):
+
def setUp(self):
self.net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
self.num_relus = self.get_num_modules(torch.nn.ReLU)
diff --git a/tests/test_require_pkg.py b/tests/test_require_pkg.py
index b1a3d82a17..065a7509a4 100644
--- a/tests/test_require_pkg.py
+++ b/tests/test_require_pkg.py
@@ -17,7 +17,9 @@
class TestRequirePkg(unittest.TestCase):
+
def test_class(self):
+
@require_pkg(pkg_name="torch", version="1.4", version_checker=min_version)
class TestClass:
pass
@@ -25,6 +27,7 @@ class TestClass:
TestClass()
def test_function(self):
+
@require_pkg(pkg_name="torch", version="1.4", version_checker=min_version)
def test_func(x):
return x
@@ -32,6 +35,7 @@ def test_func(x):
test_func(x=None)
def test_warning(self):
+
@require_pkg(pkg_name="test123", raise_error=False)
def test_func(x):
return x
diff --git a/tests/test_resample.py b/tests/test_resample.py
index c90dc5f13d..68b08b8b87 100644
--- a/tests/test_resample.py
+++ b/tests/test_resample.py
@@ -35,6 +35,7 @@ def rotate_90_2d():
class TestResampleFunction(unittest.TestCase):
+
@parameterized.expand(RESAMPLE_FUNCTION_CASES)
def test_resample_function_impl(self, img, matrix, expected):
out = resample(convert_to_tensor(img), matrix, {"lazy_shape": img.shape[1:], "lazy_padding_mode": "border"})
diff --git a/tests/test_resample_backends.py b/tests/test_resample_backends.py
index 97ee0731e8..7ddd9c7ec2 100644
--- a/tests/test_resample_backends.py
+++ b/tests/test_resample_backends.py
@@ -44,6 +44,7 @@
@SkipIfBeforePyTorchVersion((1, 9, 1))
class TestResampleBackends(unittest.TestCase):
+
@parameterized.expand(TEST_IDENTITY)
def test_resample_identity(self, input_param, im_type, interp, padding, input_shape):
"""test resampling of an identity grid with padding 2, im_type, interp, padding, input_shape"""
diff --git a/tests/test_resample_datalist.py b/tests/test_resample_datalist.py
index ae52492953..ac5cb25bb3 100644
--- a/tests/test_resample_datalist.py
+++ b/tests/test_resample_datalist.py
@@ -32,6 +32,7 @@
class TestResampleDatalist(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_value_shape(self, input_param, expected):
result = resample_datalist(**input_param)
diff --git a/tests/test_resample_to_match.py b/tests/test_resample_to_match.py
index b12ffd04be..f0d34547a7 100644
--- a/tests/test_resample_to_match.py
+++ b/tests/test_resample_to_match.py
@@ -46,6 +46,7 @@ def get_rand_fname(len=10, suffix=".nii.gz"):
@unittest.skipUnless(has_itk, "itk not installed")
class TestResampleToMatch(unittest.TestCase):
+
@classmethod
def setUpClass(cls):
super(__class__, cls).setUpClass()
diff --git a/tests/test_resample_to_matchd.py b/tests/test_resample_to_matchd.py
index 748e830bdd..9d104bf392 100644
--- a/tests/test_resample_to_matchd.py
+++ b/tests/test_resample_to_matchd.py
@@ -36,6 +36,7 @@ def update_fname(d):
class TestResampleToMatchd(unittest.TestCase):
+
@classmethod
def setUpClass(cls):
super(__class__, cls).setUpClass()
diff --git a/tests/test_resampler.py b/tests/test_resampler.py
index 50ea344090..af0db657aa 100644
--- a/tests/test_resampler.py
+++ b/tests/test_resampler.py
@@ -152,6 +152,7 @@
class TestResample(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_resample(self, input_param, input_data, expected_val):
g = Resample(**input_param)
diff --git a/tests/test_resize.py b/tests/test_resize.py
index 97a8f8dab2..33abfe4e1f 100644
--- a/tests/test_resize.py
+++ b/tests/test_resize.py
@@ -39,6 +39,7 @@
class TestResize(NumpyImageTestCase2D):
+
def test_invalid_inputs(self):
with self.assertRaises(ValueError):
resize = Resize(spatial_size=(128, 128, 3), mode="order")
diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py
index 287df039b8..daf257f89f 100644
--- a/tests/test_resize_with_pad_or_crop.py
+++ b/tests/test_resize_with_pad_or_crop.py
@@ -48,6 +48,7 @@
class TestResizeWithPadOrCrop(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_pad_shape(self, input_param, input_shape, expected_shape, _):
for p in TEST_NDARRAYS_ALL:
diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py
index 471144a609..391e0feb22 100644
--- a/tests/test_resize_with_pad_or_cropd.py
+++ b/tests/test_resize_with_pad_or_cropd.py
@@ -46,6 +46,7 @@
class TestResizeWithPadOrCropd(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_pad_shape(self, input_param, input_data, expected_val):
for p in TEST_NDARRAYS_ALL:
diff --git a/tests/test_resized.py b/tests/test_resized.py
index bd711b33d8..ab4c9815ea 100644
--- a/tests/test_resized.py
+++ b/tests/test_resized.py
@@ -59,6 +59,7 @@
class TestResized(NumpyImageTestCase2D):
+
def test_invalid_inputs(self):
with self.assertRaises(ValueError):
resize = Resized(keys="img", spatial_size=(128, 128, 3), mode="order")
diff --git a/tests/test_resnet.py b/tests/test_resnet.py
index 15ec6353f9..ad1aad8fc6 100644
--- a/tests/test_resnet.py
+++ b/tests/test_resnet.py
@@ -192,6 +192,7 @@
class TestResNet(unittest.TestCase):
+
def setUp(self):
self.tmp_ckpt_filename = os.path.join("tests", "monai_unittest_tmp_ckpt.pth")
diff --git a/tests/test_retinanet.py b/tests/test_retinanet.py
index 074a5b63fa..f36708d5b3 100644
--- a/tests/test_retinanet.py
+++ b/tests/test_retinanet.py
@@ -101,6 +101,7 @@
@unittest.skipUnless(has_torchvision, "Requires torchvision")
@skip_if_quick
class TestRetinaNet(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_retina_shape(self, model, input_param, input_shape):
backbone = model(**input_param)
diff --git a/tests/test_retinanet_detector.py b/tests/test_retinanet_detector.py
index 7292bc0c49..691254fd87 100644
--- a/tests/test_retinanet_detector.py
+++ b/tests/test_retinanet_detector.py
@@ -93,6 +93,7 @@
class NaiveNetwork(torch.nn.Module):
+
def __init__(self, spatial_dims, num_classes, **kwargs):
super().__init__()
self.spatial_dims = spatial_dims
@@ -114,6 +115,7 @@ def forward(self, images):
@unittest.skipUnless(has_torchvision, "Requires torchvision")
@skip_if_quick
class TestRetinaNetDetector(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_retina_detector_resnet_backbone_shape(self, input_param, input_shape):
returned_layers = [1]
diff --git a/tests/test_retinanet_predict_utils.py b/tests/test_retinanet_predict_utils.py
index d97806e91c..d909699469 100644
--- a/tests/test_retinanet_predict_utils.py
+++ b/tests/test_retinanet_predict_utils.py
@@ -85,6 +85,7 @@
class NaiveNetwork(torch.nn.Module):
+
def __init__(self, spatial_dims, num_classes, **kwargs):
super().__init__()
self.spatial_dims = spatial_dims
@@ -103,6 +104,7 @@ def forward(self, images):
class NaiveNetwork2(torch.nn.Module):
+
def __init__(self, spatial_dims, num_classes, **kwargs):
super().__init__()
self.spatial_dims = spatial_dims
@@ -121,6 +123,7 @@ def forward(self, images):
class TestPredictor(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_naive_predictor(self, input_param, input_shape):
net = NaiveNetwork(**input_param)
diff --git a/tests/test_rotate.py b/tests/test_rotate.py
index 95c63e65f7..19fbd1409f 100644
--- a/tests/test_rotate.py
+++ b/tests/test_rotate.py
@@ -52,6 +52,7 @@
class TestRotate2D(NumpyImageTestCase2D):
+
@parameterized.expand(TEST_CASES_2D)
def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):
init_param = {
@@ -90,6 +91,7 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al
class TestRotate3D(NumpyImageTestCase3D):
+
@parameterized.expand(TEST_CASES_3D)
def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):
init_param = {
diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py
index 0948469df9..ebc3fba7e0 100644
--- a/tests/test_rotate90.py
+++ b/tests/test_rotate90.py
@@ -31,6 +31,7 @@
class TestRotate90(NumpyImageTestCase2D):
+
def test_rotate90_default(self):
rotate = Rotate90()
for p in TEST_NDARRAYS_ALL:
@@ -102,6 +103,7 @@ def test_prob_k_spatial_axes(self):
class TestRotate903d(NumpyImageTestCase3D):
+
def test_rotate90_default(self):
rotate = Rotate90()
for p in TEST_NDARRAYS_ALL:
@@ -169,6 +171,7 @@ def test_prob_k_spatial_axes(self):
@unittest.skipUnless(optional_import("scipy")[1], "Requires scipy library.")
class TestRot90Consistency(unittest.TestCase):
+
@parameterized.expand([[2], [3], [4]])
def test_affine_rot90(self, s):
"""s"""
diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py
index 08d3a97498..ffe920992a 100644
--- a/tests/test_rotate90d.py
+++ b/tests/test_rotate90d.py
@@ -22,6 +22,7 @@
class TestRotate90d(NumpyImageTestCase2D):
+
def test_rotate90_default(self):
key = "test"
rotate = Rotate90d(keys=key)
diff --git a/tests/test_rotated.py b/tests/test_rotated.py
index 3755ab1344..28ca755661 100644
--- a/tests/test_rotated.py
+++ b/tests/test_rotated.py
@@ -43,6 +43,7 @@
@unittest.skipIf(USE_COMPILED, "unittests are not designed for both USE_COMPILED=True/False")
class TestRotated2D(NumpyImageTestCase2D):
+
@parameterized.expand(TEST_CASES_2D)
def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):
init_param = {
@@ -94,6 +95,7 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al
@unittest.skipIf(USE_COMPILED, "unittests are not designed for both USE_COMPILED=True/False")
class TestRotated3D(NumpyImageTestCase3D):
+
@parameterized.expand(TEST_CASES_3D)
def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):
init_param = {
@@ -143,6 +145,7 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al
@unittest.skipIf(USE_COMPILED, "unittests are not designed for both USE_COMPILED=True/False")
class TestRotated3DXY(NumpyImageTestCase3D):
+
@parameterized.expand(TEST_CASES_3D)
def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):
rotate_fn = Rotated(
diff --git a/tests/test_safe_dtype_range.py b/tests/test_safe_dtype_range.py
index 73f9607d7d..61b55635ae 100644
--- a/tests/test_safe_dtype_range.py
+++ b/tests/test_safe_dtype_range.py
@@ -54,6 +54,7 @@
class TesSafeDtypeRange(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_safe_dtype_range(self, in_image, im_out, out_dtype):
result = safe_dtype_range(in_image, out_dtype)
diff --git a/tests/test_saliency_inferer.py b/tests/test_saliency_inferer.py
index 4efe30d7a6..70ec048d1c 100644
--- a/tests/test_saliency_inferer.py
+++ b/tests/test_saliency_inferer.py
@@ -28,6 +28,7 @@
class TestSaliencyInferer(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_shape(self, cam_name):
model = DenseNet(
diff --git a/tests/test_sample_slices.py b/tests/test_sample_slices.py
index 02b7926392..a183689970 100644
--- a/tests/test_sample_slices.py
+++ b/tests/test_sample_slices.py
@@ -32,6 +32,7 @@
class TestSampleSlices(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_shape(self, input_data, dim, as_indices, vals, expected_result):
for p in TEST_NDARRAYS:
diff --git a/tests/test_sampler_dist.py b/tests/test_sampler_dist.py
index b2f86c54cc..b8bd1c7a9f 100644
--- a/tests/test_sampler_dist.py
+++ b/tests/test_sampler_dist.py
@@ -24,6 +24,7 @@
class DistributedSamplerTest(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2)
def test_even(self):
data = [1, 2, 3, 4, 5]
diff --git a/tests/test_save_classificationd.py b/tests/test_save_classificationd.py
index dd0b213bd6..9a7d4fc3f5 100644
--- a/tests/test_save_classificationd.py
+++ b/tests/test_save_classificationd.py
@@ -26,6 +26,7 @@
class TestSaveClassificationd(unittest.TestCase):
+
def test_saved_content(self):
with tempfile.TemporaryDirectory() as tempdir:
data = [
diff --git a/tests/test_save_image.py b/tests/test_save_image.py
index d88db201ce..ed7061095d 100644
--- a/tests/test_save_image.py
+++ b/tests/test_save_image.py
@@ -42,6 +42,7 @@
@unittest.skipUnless(has_itk, "itk not installed")
class TestSaveImage(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_saved_content(self, test_data, meta_data, output_ext, resample):
if meta_data is not None:
diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py
index ab0b9c0d9f..d2095a7554 100644
--- a/tests/test_save_imaged.py
+++ b/tests/test_save_imaged.py
@@ -54,6 +54,7 @@
@unittest.skipUnless(has_itk, "itk not installed")
class TestSaveImaged(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_saved_content(self, test_data, output_ext, resample):
with tempfile.TemporaryDirectory() as tempdir:
@@ -73,7 +74,9 @@ def test_saved_content(self, test_data, output_ext, resample):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_custom_folderlayout(self, test_data, output_ext, resample):
+
class TestFolderLayout(FolderLayoutBase):
+
def __init__(self, basepath: Path, extension: str, makedirs: bool):
self.basepath = basepath
self.ext = extension
diff --git a/tests/test_save_state.py b/tests/test_save_state.py
index 8ab7080700..0581a3ce1f 100644
--- a/tests/test_save_state.py
+++ b/tests/test_save_state.py
@@ -43,6 +43,7 @@
class TestSaveState(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_1,
diff --git a/tests/test_savitzky_golay_filter.py b/tests/test_savitzky_golay_filter.py
index b7f89cdfde..7c60287e2d 100644
--- a/tests/test_savitzky_golay_filter.py
+++ b/tests/test_savitzky_golay_filter.py
@@ -100,6 +100,7 @@
class TestSavitzkyGolayCPU(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, TEST_CASE_SINE_SMOOTH]
)
@@ -109,6 +110,7 @@ def test_value(self, arguments, image, expected_data, atol, rtol=1e-5):
class TestSavitzkyGolayCPUREP(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP]
)
@@ -119,6 +121,7 @@ def test_value(self, arguments, image, expected_data, atol, rtol=1e-5):
@skip_if_no_cuda
class TestSavitzkyGolayGPU(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, TEST_CASE_SINE_SMOOTH]
)
@@ -129,6 +132,7 @@ def test_value(self, arguments, image, expected_data, atol, rtol=1e-5):
@skip_if_no_cuda
class TestSavitzkyGolayGPUREP(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP]
)
diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py
index 6da4f24c62..14e403e238 100644
--- a/tests/test_savitzky_golay_smooth.py
+++ b/tests/test_savitzky_golay_smooth.py
@@ -60,6 +60,7 @@
class TestSavitzkyGolaySmooth(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH, TEST_CASE_SINGLE_VALUE_REP]
)
diff --git a/tests/test_savitzky_golay_smoothd.py b/tests/test_savitzky_golay_smoothd.py
index 7e7176e2bb..3bb4056046 100644
--- a/tests/test_savitzky_golay_smoothd.py
+++ b/tests/test_savitzky_golay_smoothd.py
@@ -60,6 +60,7 @@
class TestSavitzkyGolaySmoothd(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH, TEST_CASE_SINGLE_VALUE_REP]
)
diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py
index 57a7da1780..17dfe305b2 100644
--- a/tests/test_scale_intensity.py
+++ b/tests/test_scale_intensity.py
@@ -22,6 +22,7 @@
class TestScaleIntensity(NumpyImageTestCase2D):
+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_range_scale(self, p):
scaler = ScaleIntensity(minv=1.0, maxv=2.0)
diff --git a/tests/test_scale_intensity_fixed_mean.py b/tests/test_scale_intensity_fixed_mean.py
index afbcd46141..35d38ef0b1 100644
--- a/tests/test_scale_intensity_fixed_mean.py
+++ b/tests/test_scale_intensity_fixed_mean.py
@@ -21,6 +21,7 @@
class TestScaleIntensityFixedMean(NumpyImageTestCase2D):
+
def test_factor_scale(self):
for p in TEST_NDARRAYS:
scaler = ScaleIntensityFixedMean(factor=0.1, fixed_mean=False)
diff --git a/tests/test_scale_intensity_range.py b/tests/test_scale_intensity_range.py
index 898f4dfb45..6013a237db 100644
--- a/tests/test_scale_intensity_range.py
+++ b/tests/test_scale_intensity_range.py
@@ -20,6 +20,7 @@
class IntensityScaleIntensityRange(NumpyImageTestCase2D):
+
def test_image_scale_intensity_range(self):
scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80, dtype=np.uint8)
for p in TEST_NDARRAYS:
diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py
index 583dcec07e..7c3a684a00 100644
--- a/tests/test_scale_intensity_range_percentiles.py
+++ b/tests/test_scale_intensity_range_percentiles.py
@@ -20,6 +20,7 @@
class TestScaleIntensityRangePercentiles(NumpyImageTestCase2D):
+
def test_scaling(self):
img = self.imt[0]
lower = 10
diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py
index 8e2511d9e4..ab0347fbbf 100644
--- a/tests/test_scale_intensity_range_percentilesd.py
+++ b/tests/test_scale_intensity_range_percentilesd.py
@@ -20,6 +20,7 @@
class TestScaleIntensityRangePercentilesd(NumpyImageTestCase2D):
+
def test_scaling(self):
img = self.imt
lower = 10
diff --git a/tests/test_scale_intensity_ranged.py b/tests/test_scale_intensity_ranged.py
index 724acf1c73..cc3f1220e7 100644
--- a/tests/test_scale_intensity_ranged.py
+++ b/tests/test_scale_intensity_ranged.py
@@ -18,6 +18,7 @@
class IntensityScaleIntensityRanged(NumpyImageTestCase2D):
+
def test_image_scale_intensity_ranged(self):
key = "img"
scaler = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=50, b_max=80)
diff --git a/tests/test_scale_intensityd.py b/tests/test_scale_intensityd.py
index 6705cfda9d..88beece894 100644
--- a/tests/test_scale_intensityd.py
+++ b/tests/test_scale_intensityd.py
@@ -20,6 +20,7 @@
class TestScaleIntensityd(NumpyImageTestCase2D):
+
def test_range_scale(self):
key = "img"
for p in TEST_NDARRAYS:
diff --git a/tests/test_se_block.py b/tests/test_se_block.py
index de129f4d55..ca60643635 100644
--- a/tests/test_se_block.py
+++ b/tests/test_se_block.py
@@ -63,6 +63,7 @@
class TestSEBlockLayer(unittest.TestCase):
+
@parameterized.expand(TEST_CASES + TEST_CASES_3D)
def test_shape(self, input_param, input_shape, expected_shape):
net = SEBlock(**input_param).to(device)
diff --git a/tests/test_se_blocks.py b/tests/test_se_blocks.py
index c97e459f50..c1e72749cc 100644
--- a/tests/test_se_blocks.py
+++ b/tests/test_se_blocks.py
@@ -41,6 +41,7 @@
class TestChannelSELayer(unittest.TestCase):
+
@parameterized.expand(TEST_CASES + TEST_CASES_3D)
def test_shape(self, input_param, input_shape, expected_shape):
net = ChannelSELayer(**input_param)
@@ -60,6 +61,7 @@ def test_ill_arg(self):
class TestResidualSELayer(unittest.TestCase):
+
@parameterized.expand(TEST_CASES[:1])
def test_shape(self, input_param, input_shape, expected_shape):
net = ResidualSELayer(**input_param)
diff --git a/tests/test_seg_loss_integration.py b/tests/test_seg_loss_integration.py
index 23bc63fbf6..6713e7bba9 100644
--- a/tests/test_seg_loss_integration.py
+++ b/tests/test_seg_loss_integration.py
@@ -47,6 +47,7 @@
class TestSegLossIntegration(unittest.TestCase):
+
def setUp(self):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
@@ -92,6 +93,7 @@ def test_convergence(self, loss_type, loss_args, forward_args):
# define a one layer model
class OnelayerNet(nn.Module):
+
def __init__(self):
super().__init__()
self.layer_1 = nn.Linear(num_voxels, 200)
diff --git a/tests/test_segresnet.py b/tests/test_segresnet.py
index cb34445efa..728699c434 100644
--- a/tests/test_segresnet.py
+++ b/tests/test_segresnet.py
@@ -83,6 +83,7 @@
class TestResNet(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_SEGRESNET + TEST_CASE_SEGRESNET_2)
def test_shape(self, input_param, input_shape, expected_shape):
net = SegResNet(**input_param).to(device)
@@ -102,6 +103,7 @@ def test_script(self):
class TestResNetVAE(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_SEGRESNET_VAE)
def test_vae_shape(self, input_param, input_shape, expected_shape):
net = SegResNetVAE(**input_param).to(device)
diff --git a/tests/test_segresnet_block.py b/tests/test_segresnet_block.py
index 343f39d72c..633507a06a 100644
--- a/tests/test_segresnet_block.py
+++ b/tests/test_segresnet_block.py
@@ -38,6 +38,7 @@
class TestResBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_RESBLOCK)
def test_shape(self, input_param, input_shape, expected_shape):
net = ResBlock(**input_param)
diff --git a/tests/test_segresnet_ds.py b/tests/test_segresnet_ds.py
index a5b88f9724..5372fcc8ae 100644
--- a/tests/test_segresnet_ds.py
+++ b/tests/test_segresnet_ds.py
@@ -72,6 +72,7 @@
class TestResNetDS(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_SEGRESNET_DS)
def test_shape(self, input_param, input_shape, expected_shape):
net = SegResNetDS(**input_param).to(device)
diff --git a/tests/test_select_cross_validation_folds.py b/tests/test_select_cross_validation_folds.py
index 3ab6c0a9c5..c7d19f34ab 100644
--- a/tests/test_select_cross_validation_folds.py
+++ b/tests/test_select_cross_validation_folds.py
@@ -43,6 +43,7 @@
class TestSelectCrossValidationFolds(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_value(self, input_param, result):
partitions = partition_dataset(**input_param)
diff --git a/tests/test_select_itemsd.py b/tests/test_select_itemsd.py
index 5eb4a1c51b..f025917b9d 100644
--- a/tests/test_select_itemsd.py
+++ b/tests/test_select_itemsd.py
@@ -23,6 +23,7 @@
class TestSelectItemsd(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1])
def test_memory(self, input_param, expected_key_size):
input_data = {}
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 6062b5352f..b8be4fd1b6 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -37,6 +37,7 @@
class TestResBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_SABLOCK)
@skipUnless(has_einops, "Requires einops")
def test_shape(self, input_param, input_shape, expected_shape):
diff --git a/tests/test_senet.py b/tests/test_senet.py
index 92b5f39ace..6809d4562b 100644
--- a/tests/test_senet.py
+++ b/tests/test_senet.py
@@ -58,6 +58,7 @@
class TestSENET(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
def test_senet_shape(self, net, net_args):
input_data = torch.randn(2, 2, 64, 64, 64).to(device)
@@ -75,6 +76,7 @@ def test_script(self, net, net_args):
class TestPretrainedSENET(unittest.TestCase):
+
def setUp(self):
self.original_urls = se_mod.SE_NET_MODELS.copy()
replace_url = test_is_quick()
diff --git a/tests/test_separable_filter.py b/tests/test_separable_filter.py
index 1797a649e0..d712f05ee1 100644
--- a/tests/test_separable_filter.py
+++ b/tests/test_separable_filter.py
@@ -20,6 +20,7 @@
class SeparableFilterTestCase(unittest.TestCase):
+
def test_1d(self):
a = torch.tensor([[list(range(10))]], dtype=torch.float)
out = separable_filtering(a, torch.tensor([-1, 0, 1]))
diff --git a/tests/test_set_determinism.py b/tests/test_set_determinism.py
index aab7af1079..7d64aed244 100644
--- a/tests/test_set_determinism.py
+++ b/tests/test_set_determinism.py
@@ -21,6 +21,7 @@
class TestSetDeterminism(unittest.TestCase):
+
def test_values(self):
# check system default flags
set_determinism(None)
@@ -55,6 +56,7 @@ def test_values(self):
class TestSetFlag(unittest.TestCase):
+
def setUp(self):
set_determinism(1, use_deterministic_algorithms=True)
diff --git a/tests/test_set_visible_devices.py b/tests/test_set_visible_devices.py
index 993e8a4ac2..7860656b3d 100644
--- a/tests/test_set_visible_devices.py
+++ b/tests/test_set_visible_devices.py
@@ -18,6 +18,7 @@
class TestVisibleDevices(unittest.TestCase):
+
@staticmethod
def run_process_and_get_exit_code(code_to_execute):
value = os.system(code_to_execute)
diff --git a/tests/test_shift_intensity.py b/tests/test_shift_intensity.py
index f1bc36036e..90aa0f9271 100644
--- a/tests/test_shift_intensity.py
+++ b/tests/test_shift_intensity.py
@@ -20,6 +20,7 @@
class TestShiftIntensity(NumpyImageTestCase2D):
+
def test_value(self):
shifter = ShiftIntensity(offset=1.0)
result = shifter(self.imt)
diff --git a/tests/test_shift_intensityd.py b/tests/test_shift_intensityd.py
index e8d163b34a..22336b4415 100644
--- a/tests/test_shift_intensityd.py
+++ b/tests/test_shift_intensityd.py
@@ -21,6 +21,7 @@
class TestShiftIntensityd(NumpyImageTestCase2D):
+
def test_value(self):
key = "img"
for p in TEST_NDARRAYS:
diff --git a/tests/test_shuffle_buffer.py b/tests/test_shuffle_buffer.py
index 9fcd3a23f6..e75321616b 100644
--- a/tests/test_shuffle_buffer.py
+++ b/tests/test_shuffle_buffer.py
@@ -23,6 +23,7 @@
@SkipIfBeforePyTorchVersion((1, 12))
class TestShuffleBuffer(unittest.TestCase):
+
def test_shape(self):
buffer = ShuffleBuffer([1, 2, 3, 4], seed=0)
num_workers = 2 if sys.platform == "linux" else 0
diff --git a/tests/test_signal_continuouswavelet.py b/tests/test_signal_continuouswavelet.py
index 4886168a00..7e6ee8b105 100644
--- a/tests/test_signal_continuouswavelet.py
+++ b/tests/test_signal_continuouswavelet.py
@@ -29,6 +29,7 @@
@skipUnless(has_pywt, "pywt required")
class TestSignalContinousWavelet(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, type, length, frequency):
self.assertIsInstance(SignalContinuousWavelet(type, length, frequency), SignalContinuousWavelet)
diff --git a/tests/test_signal_fillempty.py b/tests/test_signal_fillempty.py
index ee606d960c..a3ee623cc5 100644
--- a/tests/test_signal_fillempty.py
+++ b/tests/test_signal_fillempty.py
@@ -26,6 +26,7 @@
@SkipIfBeforePyTorchVersion((1, 9))
class TestSignalFillEmptyNumpy(unittest.TestCase):
+
def test_correct_parameters_multi_channels(self):
self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty)
sig = np.load(TEST_SIGNAL)
@@ -37,6 +38,7 @@ def test_correct_parameters_multi_channels(self):
@SkipIfBeforePyTorchVersion((1, 9))
class TestSignalFillEmptyTorch(unittest.TestCase):
+
def test_correct_parameters_multi_channels(self):
self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty)
sig = convert_to_tensor(np.load(TEST_SIGNAL))
diff --git a/tests/test_signal_fillemptyd.py b/tests/test_signal_fillemptyd.py
index 5b12055e7d..ee8c571ef8 100644
--- a/tests/test_signal_fillemptyd.py
+++ b/tests/test_signal_fillemptyd.py
@@ -26,6 +26,7 @@
@SkipIfBeforePyTorchVersion((1, 9))
class TestSignalFillEmptyNumpy(unittest.TestCase):
+
def test_correct_parameters_multi_channels(self):
self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd)
sig = np.load(TEST_SIGNAL)
@@ -41,6 +42,7 @@ def test_correct_parameters_multi_channels(self):
@SkipIfBeforePyTorchVersion((1, 9))
class TestSignalFillEmptyTorch(unittest.TestCase):
+
def test_correct_parameters_multi_channels(self):
self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd)
sig = convert_to_tensor(np.load(TEST_SIGNAL))
diff --git a/tests/test_signal_rand_add_gaussiannoise.py b/tests/test_signal_rand_add_gaussiannoise.py
index 2090df876f..e5c9eba8a2 100644
--- a/tests/test_signal_rand_add_gaussiannoise.py
+++ b/tests/test_signal_rand_add_gaussiannoise.py
@@ -25,6 +25,7 @@
class TestSignalRandAddGaussianNoiseNumpy(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries):
self.assertIsInstance(SignalRandAddGaussianNoise(boundaries), SignalRandAddGaussianNoise)
@@ -35,6 +36,7 @@ def test_correct_parameters_multi_channels(self, boundaries):
class TestSignalRandAddGaussianNoiseTorch(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries):
self.assertIsInstance(SignalRandAddGaussianNoise(boundaries), SignalRandAddGaussianNoise)
diff --git a/tests/test_signal_rand_add_sine.py b/tests/test_signal_rand_add_sine.py
index ae0684d608..4ba91247dd 100644
--- a/tests/test_signal_rand_add_sine.py
+++ b/tests/test_signal_rand_add_sine.py
@@ -25,6 +25,7 @@
class TestSignalRandAddSineNumpy(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries, freqs):
self.assertIsInstance(SignalRandAddSine(boundaries, freqs), SignalRandAddSine)
@@ -35,6 +36,7 @@ def test_correct_parameters_multi_channels(self, boundaries, freqs):
class TestSignalRandAddSineTorch(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries, freqs):
self.assertIsInstance(SignalRandAddSine(boundaries, freqs), SignalRandAddSine)
diff --git a/tests/test_signal_rand_add_sine_partial.py b/tests/test_signal_rand_add_sine_partial.py
index 109fb006ea..71b67747a2 100644
--- a/tests/test_signal_rand_add_sine_partial.py
+++ b/tests/test_signal_rand_add_sine_partial.py
@@ -25,6 +25,7 @@
class TestSignalRandAddSinePartialNumpy(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction):
self.assertIsInstance(SignalRandAddSinePartial(boundaries, frequencies, fraction), SignalRandAddSinePartial)
@@ -35,6 +36,7 @@ def test_correct_parameters_multi_channels(self, boundaries, frequencies, fracti
class TestSignalRandAddSinePartialTorch(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction):
self.assertIsInstance(SignalRandAddSinePartial(boundaries, frequencies, fraction), SignalRandAddSinePartial)
diff --git a/tests/test_signal_rand_add_squarepulse.py b/tests/test_signal_rand_add_squarepulse.py
index efbdc9af09..e1432029ea 100644
--- a/tests/test_signal_rand_add_squarepulse.py
+++ b/tests/test_signal_rand_add_squarepulse.py
@@ -31,6 +31,7 @@
@skipUnless(has_scipy, "scipy required")
@SkipIfBeforePyTorchVersion((1, 10, 1))
class TestSignalRandAddSquarePulseNumpy(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries, frequencies):
self.assertIsInstance(SignalRandAddSquarePulse(boundaries, frequencies), SignalRandAddSquarePulse)
@@ -43,6 +44,7 @@ def test_correct_parameters_multi_channels(self, boundaries, frequencies):
@skipUnless(has_scipy, "scipy required")
@SkipIfBeforePyTorchVersion((1, 10, 1))
class TestSignalRandAddSquarePulseTorch(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries, frequencies):
self.assertIsInstance(SignalRandAddSquarePulse(boundaries, frequencies), SignalRandAddSquarePulse)
diff --git a/tests/test_signal_rand_add_squarepulse_partial.py b/tests/test_signal_rand_add_squarepulse_partial.py
index eee3f5596d..7e1c2bb9d8 100644
--- a/tests/test_signal_rand_add_squarepulse_partial.py
+++ b/tests/test_signal_rand_add_squarepulse_partial.py
@@ -31,6 +31,7 @@
@skipUnless(has_scipy, "scipy required")
@SkipIfBeforePyTorchVersion((1, 10, 1))
class TestSignalRandAddSquarePulsePartialNumpy(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction):
self.assertIsInstance(
@@ -45,6 +46,7 @@ def test_correct_parameters_multi_channels(self, boundaries, frequencies, fracti
@skipUnless(has_scipy, "scipy required")
@SkipIfBeforePyTorchVersion((1, 10, 1))
class TestSignalRandAddSquarePulsePartialTorch(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction):
self.assertIsInstance(
diff --git a/tests/test_signal_rand_drop.py b/tests/test_signal_rand_drop.py
index 5dcd466481..bf2db75a6a 100644
--- a/tests/test_signal_rand_drop.py
+++ b/tests/test_signal_rand_drop.py
@@ -25,6 +25,7 @@
class TestSignalRandDropNumpy(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries):
self.assertIsInstance(SignalRandDrop(boundaries), SignalRandDrop)
@@ -35,6 +36,7 @@ def test_correct_parameters_multi_channels(self, boundaries):
class TestSignalRandDropTorch(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries):
self.assertIsInstance(SignalRandDrop(boundaries), SignalRandDrop)
diff --git a/tests/test_signal_rand_scale.py b/tests/test_signal_rand_scale.py
index 126d7cca65..c040c59a1f 100644
--- a/tests/test_signal_rand_scale.py
+++ b/tests/test_signal_rand_scale.py
@@ -25,6 +25,7 @@
class TestSignalRandScaleNumpy(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries):
self.assertIsInstance(SignalRandScale(boundaries), SignalRandScale)
@@ -35,6 +36,7 @@ def test_correct_parameters_multi_channels(self, boundaries):
class TestSignalRandScaleTorch(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, boundaries):
self.assertIsInstance(SignalRandScale(boundaries), SignalRandScale)
diff --git a/tests/test_signal_rand_shift.py b/tests/test_signal_rand_shift.py
index ed25cc8b1f..96809e7446 100644
--- a/tests/test_signal_rand_shift.py
+++ b/tests/test_signal_rand_shift.py
@@ -29,6 +29,7 @@
@skipUnless(has_scipy, "scipy required")
class TestSignalRandShiftNumpy(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, mode, filling, boundaries):
self.assertIsInstance(SignalRandShift(mode, filling, boundaries), SignalRandShift)
@@ -40,6 +41,7 @@ def test_correct_parameters_multi_channels(self, mode, filling, boundaries):
@skipUnless(has_scipy, "scipy required")
class TestSignalRandShiftTorch(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, mode, filling, boundaries):
self.assertIsInstance(SignalRandShift(mode, filling, boundaries), SignalRandShift)
diff --git a/tests/test_signal_remove_frequency.py b/tests/test_signal_remove_frequency.py
index b18de36c08..9f795ce68b 100644
--- a/tests/test_signal_remove_frequency.py
+++ b/tests/test_signal_remove_frequency.py
@@ -31,6 +31,7 @@
@skipUnless(has_scipy and has_torchaudio, "scipy and torchaudio are required")
class TestSignalRemoveFrequencyNumpy(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, frequency, quality_factor, sampling_freq):
self.assertIsInstance(SignalRemoveFrequency(frequency, quality_factor, sampling_freq), SignalRemoveFrequency)
@@ -49,6 +50,7 @@ def test_correct_parameters_multi_channels(self, frequency, quality_factor, samp
@skipUnless(has_scipy and has_torchaudio, "scipy and torchaudio are required")
class TestSignalRemoveFrequencyTorch(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct_parameters_multi_channels(self, frequency, quality_factor, sampling_freq):
self.assertIsInstance(SignalRemoveFrequency(frequency, quality_factor, sampling_freq), SignalRemoveFrequency)
diff --git a/tests/test_simple_aspp.py b/tests/test_simple_aspp.py
index f18b208e9c..da7540d45e 100644
--- a/tests/test_simple_aspp.py
+++ b/tests/test_simple_aspp.py
@@ -69,6 +69,7 @@
class TestChannelSELayer(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_shape, expected_shape):
net = SimpleASPP(**input_param)
diff --git a/tests/test_simulatedelay.py b/tests/test_simulatedelay.py
index 5cf47b245e..0a4f23450a 100644
--- a/tests/test_simulatedelay.py
+++ b/tests/test_simulatedelay.py
@@ -22,6 +22,7 @@
class TestSimulateDelay(NumpyImageTestCase2D):
+
@parameterized.expand([(0.45,), (1,)])
def test_value(self, delay_test_time: float):
resize = SimulateDelay(delay_time=delay_test_time)
diff --git a/tests/test_simulatedelayd.py b/tests/test_simulatedelayd.py
index 827fe69510..419e21f24d 100644
--- a/tests/test_simulatedelayd.py
+++ b/tests/test_simulatedelayd.py
@@ -22,6 +22,7 @@
class TestSimulateDelay(NumpyImageTestCase2D):
+
@parameterized.expand([(0.45,), (1,)])
def test_value(self, delay_test_time: float):
resize = SimulateDelayd(keys="imgd", delay_time=delay_test_time)
diff --git a/tests/test_skip_connection.py b/tests/test_skip_connection.py
index 0ac8ef0d7a..5ee166cf10 100644
--- a/tests/test_skip_connection.py
+++ b/tests/test_skip_connection.py
@@ -31,6 +31,7 @@
class TestSkipConnection(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_3D)
def test_shape(self, input_param, input_shape, expected_shape):
net = SkipConnection(submodule=torch.nn.Softmax(dim=1), **input_param)
diff --git a/tests/test_slice_inferer.py b/tests/test_slice_inferer.py
index 4d7dea026f..526542943e 100644
--- a/tests/test_slice_inferer.py
+++ b/tests/test_slice_inferer.py
@@ -23,6 +23,7 @@
class TestSliceInferer(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, spatial_dim):
spatial_dim = int(spatial_dim)
diff --git a/tests/test_sliding_patch_wsi_dataset.py b/tests/test_sliding_patch_wsi_dataset.py
index 518e94552f..6369613426 100644
--- a/tests/test_sliding_patch_wsi_dataset.py
+++ b/tests/test_sliding_patch_wsi_dataset.py
@@ -213,6 +213,7 @@ def setUpModule():
class SlidingPatchWSIDatasetTests:
+
class Tests(unittest.TestCase):
backend = None
@@ -252,6 +253,7 @@ def test_read_patches_large(self, input_parameters, expected):
@skipUnless(has_cucim, "Requires cucim")
class TestSlidingPatchWSIDatasetCuCIM(SlidingPatchWSIDatasetTests.Tests):
+
@classmethod
def setUpClass(cls):
cls.backend = "cucim"
@@ -259,6 +261,7 @@ def setUpClass(cls):
@skipUnless(has_osl, "Requires openslide")
class TestSlidingPatchWSIDatasetOpenSlide(SlidingPatchWSIDatasetTests.Tests):
+
@classmethod
def setUpClass(cls):
cls.backend = "openslide"
diff --git a/tests/test_sliding_window_hovernet_inference.py b/tests/test_sliding_window_hovernet_inference.py
index 276bd1e372..6fc9240a13 100644
--- a/tests/test_sliding_window_hovernet_inference.py
+++ b/tests/test_sliding_window_hovernet_inference.py
@@ -36,6 +36,7 @@
class TestSlidingWindowHoVerNetInference(unittest.TestCase):
+
@parameterized.expand(TEST_CASES_PADDING)
def test_sliding_window_with_padding(
self, key, image_shape, roi_shape, sw_batch_size, overlap, mode, device, extra_input_padding
diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py
index 8f0c074403..33b38a5bc7 100644
--- a/tests/test_sliding_window_inference.py
+++ b/tests/test_sliding_window_inference.py
@@ -70,8 +70,10 @@
class TestSlidingWindowInference(unittest.TestCase):
+
@parameterized.expand(BUFFER_CASES)
def test_buffers(self, size_params, buffer_steps, buffer_dim, device_params):
+
def mult_two(patch, *args, **kwargs):
return 2.0 * patch
diff --git a/tests/test_sliding_window_splitter.py b/tests/test_sliding_window_splitter.py
index 015293cbee..ad136c61a4 100644
--- a/tests/test_sliding_window_splitter.py
+++ b/tests/test_sliding_window_splitter.py
@@ -236,6 +236,7 @@ def missing_parameter_filter(patch):
class SlidingWindowSplitterTests(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_TENSOR_0,
diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py
index 0e2a79fef3..bb43060469 100644
--- a/tests/test_smartcachedataset.py
+++ b/tests/test_smartcachedataset.py
@@ -38,6 +38,7 @@
class TestSmartCacheDataset(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_shape(self, replace_rate, num_replace_workers, transform):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]).astype(float), np.eye(4))
diff --git a/tests/test_smooth_field.py b/tests/test_smooth_field.py
index c525311478..ca010641c4 100644
--- a/tests/test_smooth_field.py
+++ b/tests/test_smooth_field.py
@@ -88,6 +88,7 @@
class TestSmoothField(unittest.TestCase):
+
@parameterized.expand(TESTS_CONTRAST)
def test_rand_smooth_field_adjust_contrastd(self, input_param, input_data, expected_val):
g = RandSmoothFieldAdjustContrastd(**input_param)
diff --git a/tests/test_some_of.py b/tests/test_some_of.py
index 8880c376b9..3723732d51 100644
--- a/tests/test_some_of.py
+++ b/tests/test_some_of.py
@@ -31,21 +31,25 @@
class A(Transform):
+
def __call__(self, x):
return 2 * x
class B(Transform):
+
def __call__(self, x):
return 3 * x
class C(Transform):
+
def __call__(self, x):
return 5 * x
class D(Transform):
+
def __call__(self, x):
return 7 * x
@@ -71,6 +75,7 @@ def __call__(self, x):
class TestSomeOf(unittest.TestCase):
+
def setUp(self):
set_determinism(seed=0)
@@ -221,6 +226,7 @@ def test_bad_num_transforms(self):
class TestSomeOfAPITests(unittest.TestCase):
+
@staticmethod
def data_from_keys(keys):
if keys is None:
diff --git a/tests/test_spacing.py b/tests/test_spacing.py
index 8b664641d7..c9a6291c78 100644
--- a/tests/test_spacing.py
+++ b/tests/test_spacing.py
@@ -271,6 +271,7 @@
@skip_if_quick
class TestSpacingCase(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_spacing(
self,
diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py
index 36986b2706..1cecaabced 100644
--- a/tests/test_spacingd.py
+++ b/tests/test_spacingd.py
@@ -105,6 +105,7 @@
class TestSpacingDCase(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_spacingd(self, _, data, kw_args, expected_shape, expected_affine, device):
data = {k: v.to(device) for k, v in data.items()}
diff --git a/tests/test_spatial_combine_transforms.py b/tests/test_spatial_combine_transforms.py
index 8594daed16..8479e9084b 100644
--- a/tests/test_spatial_combine_transforms.py
+++ b/tests/test_spatial_combine_transforms.py
@@ -132,6 +132,7 @@
class CombineLazyTest(unittest.TestCase):
+
@parameterized.expand(TEST_2D + TEST_3D)
def test_combine_transforms(self, input_shape, funcs):
for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
diff --git a/tests/test_spatial_resample.py b/tests/test_spatial_resample.py
index b513bd0f05..e64b242128 100644
--- a/tests/test_spatial_resample.py
+++ b/tests/test_spatial_resample.py
@@ -133,6 +133,7 @@
class TestSpatialResample(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_flips(self, img, device, data_param, expected_output):
for p in TEST_NDARRAYS_ALL:
diff --git a/tests/test_spatial_resampled.py b/tests/test_spatial_resampled.py
index ebe3eb6e4f..541015cc34 100644
--- a/tests/test_spatial_resampled.py
+++ b/tests/test_spatial_resampled.py
@@ -87,6 +87,7 @@
class TestSpatialResample(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_flips_inverse(self, img, device, dst_affine, kwargs, expected_output):
img = MetaTensor(img, affine=torch.eye(4)).to(device)
diff --git a/tests/test_spectral_loss.py b/tests/test_spectral_loss.py
index 21b5c48de4..f62ae9030b 100644
--- a/tests/test_spectral_loss.py
+++ b/tests/test_spectral_loss.py
@@ -63,6 +63,7 @@
class TestJukeboxLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_results(self, input_param, input_data, expected_val):
results = JukeboxLoss(**input_param).forward(**input_data)
diff --git a/tests/test_splitdim.py b/tests/test_splitdim.py
index 6c678a6bc2..f557f44142 100644
--- a/tests/test_splitdim.py
+++ b/tests/test_splitdim.py
@@ -26,6 +26,7 @@
class TestSplitDim(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_correct_shape(self, shape, keepdim, im_type):
arr = im_type(np.random.rand(*shape))
diff --git a/tests/test_squeeze_unsqueeze.py b/tests/test_squeeze_unsqueeze.py
index 130a214345..3f818f905b 100644
--- a/tests/test_squeeze_unsqueeze.py
+++ b/tests/test_squeeze_unsqueeze.py
@@ -61,6 +61,7 @@
class TestUnsqueeze(unittest.TestCase):
+
@parameterized.expand(RIGHT_CASES + ALL_CASES)
def test_unsqueeze_right(self, arr, ndim, shape):
self.assertEqual(unsqueeze_right(arr, ndim).shape, shape)
diff --git a/tests/test_squeezedim.py b/tests/test_squeezedim.py
index 6673fd25c1..a295d20ef5 100644
--- a/tests/test_squeezedim.py
+++ b/tests/test_squeezedim.py
@@ -32,6 +32,7 @@
class TestSqueezeDim(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, input_param, test_data, expected_shape):
result = SqueezeDim(**input_param)(test_data)
diff --git a/tests/test_squeezedimd.py b/tests/test_squeezedimd.py
index 9fa9d84030..934479563d 100644
--- a/tests/test_squeezedimd.py
+++ b/tests/test_squeezedimd.py
@@ -80,6 +80,7 @@
class TestSqueezeDim(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, input_param, test_data, expected_shape):
result = SqueezeDimd(**input_param)(test_data)
diff --git a/tests/test_ssim_loss.py b/tests/test_ssim_loss.py
index db80eb80db..7fa593b956 100644
--- a/tests/test_ssim_loss.py
+++ b/tests/test_ssim_loss.py
@@ -23,6 +23,7 @@
class TestSSIMLoss(unittest.TestCase):
+
def test_shape(self):
set_determinism(0)
preds = torch.abs(torch.randn(2, 3, 16, 16))
diff --git a/tests/test_ssim_metric.py b/tests/test_ssim_metric.py
index 467e478937..d79107e999 100644
--- a/tests/test_ssim_metric.py
+++ b/tests/test_ssim_metric.py
@@ -20,6 +20,7 @@
class TestSSIMMetric(unittest.TestCase):
+
def test2d_gaussian(self):
set_determinism(0)
preds = torch.abs(torch.randn(2, 3, 16, 16))
diff --git a/tests/test_state_cacher.py b/tests/test_state_cacher.py
index 2037dc3951..22c2836239 100644
--- a/tests/test_state_cacher.py
+++ b/tests/test_state_cacher.py
@@ -36,6 +36,7 @@
class TestStateCacher(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_state_cacher(self, data_obj, params):
key = "data_obj"
diff --git a/tests/test_std_shift_intensity.py b/tests/test_std_shift_intensity.py
index af18c18aa2..b4dc1db568 100644
--- a/tests/test_std_shift_intensity.py
+++ b/tests/test_std_shift_intensity.py
@@ -21,6 +21,7 @@
class TestStdShiftIntensity(NumpyImageTestCase2D):
+
def test_value(self):
for p in TEST_NDARRAYS:
imt = p(self.imt)
diff --git a/tests/test_std_shift_intensityd.py b/tests/test_std_shift_intensityd.py
index 6cb7d416c7..73617ef4a3 100644
--- a/tests/test_std_shift_intensityd.py
+++ b/tests/test_std_shift_intensityd.py
@@ -21,6 +21,7 @@
class TestStdShiftIntensityd(NumpyImageTestCase2D):
+
def test_value(self):
key = "img"
factor = np.random.rand()
diff --git a/tests/test_str2bool.py b/tests/test_str2bool.py
index 36f99b4064..af932b1df8 100644
--- a/tests/test_str2bool.py
+++ b/tests/test_str2bool.py
@@ -17,6 +17,7 @@
class TestStr2Bool(unittest.TestCase):
+
def test_str_2_bool(self):
for i in ("yes", "true", "t", "y", "1", True):
self.assertTrue(str2bool(i))
diff --git a/tests/test_str2list.py b/tests/test_str2list.py
index b442925fb3..e1531373cb 100644
--- a/tests/test_str2list.py
+++ b/tests/test_str2list.py
@@ -17,6 +17,7 @@
class TestStr2List(unittest.TestCase):
+
def test_str_2_list(self):
for i in ("1,2,3", "1, 2, 3", "1,2e-0,3.0", [1, 2, 3]):
self.assertEqual(str2list(i), [1, 2, 3])
diff --git a/tests/test_subpixel_upsample.py b/tests/test_subpixel_upsample.py
index a6de8dd846..5abbe57e11 100644
--- a/tests/test_subpixel_upsample.py
+++ b/tests/test_subpixel_upsample.py
@@ -68,6 +68,7 @@
class TestSUBPIXEL(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_SUBPIXEL)
def test_subpixel_shape(self, input_param, input_shape, expected_shape):
net = SubpixelUpsample(**input_param)
diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py
index 53b0d38bb2..2ef19a4eea 100644
--- a/tests/test_surface_dice.py
+++ b/tests/test_surface_dice.py
@@ -24,6 +24,7 @@
class TestAllSurfaceDiceMetrics(unittest.TestCase):
+
def test_tolerance_euclidean_distance_with_spacing(self):
batch_size = 2
n_class = 2
diff --git a/tests/test_surface_distance.py b/tests/test_surface_distance.py
index 81ddee107b..85db389f80 100644
--- a/tests/test_surface_distance.py
+++ b/tests/test_surface_distance.py
@@ -142,6 +142,7 @@ def create_spherical_seg_3d(
class TestAllSurfaceMetrics(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_value(self, input_data, expected_value):
if len(input_data) == 3:
diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py
index e34e5a3c8e..5b33475c7e 100644
--- a/tests/test_swin_unetr.py
+++ b/tests/test_swin_unetr.py
@@ -76,6 +76,7 @@
class TestSWINUNETR(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_SWIN_UNETR)
@skipUnless(has_einops, "Requires einops")
def test_shape(self, input_param, input_shape, expected_shape):
diff --git a/tests/test_synthetic.py b/tests/test_synthetic.py
index 116897e67d..7db3c3e77a 100644
--- a/tests/test_synthetic.py
+++ b/tests/test_synthetic.py
@@ -41,6 +41,7 @@
class TestDiceCELoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_create_test_image(self, dim, input_param, expected_img, expected_seg, expected_shape, expected_max_cls):
set_determinism(seed=0)
diff --git a/tests/test_tciadataset.py b/tests/test_tciadataset.py
index 2a3928f9aa..d996922e20 100644
--- a/tests/test_tciadataset.py
+++ b/tests/test_tciadataset.py
@@ -23,6 +23,7 @@
class TestTciaDataset(unittest.TestCase):
+
@skip_if_quick
def test_values(self):
testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py
index cbb78ec64d..746ad122b2 100644
--- a/tests/test_testtimeaugmentation.py
+++ b/tests/test_testtimeaugmentation.py
@@ -52,6 +52,7 @@
class TestTestTimeAugmentation(unittest.TestCase):
+
@staticmethod
def get_data(num_examples, input_size, data_type=np.asarray, include_label=True):
custom_create_test_image_2d = partial(
diff --git a/tests/test_text_encoding.py b/tests/test_text_encoding.py
index 831803de8c..902f7a4b1d 100644
--- a/tests/test_text_encoding.py
+++ b/tests/test_text_encoding.py
@@ -18,6 +18,7 @@
class TestTextEncoder(unittest.TestCase):
+
def test_test_encoding_shape(self):
with skip_if_downloading_fails():
# test 2D encoder
diff --git a/tests/test_thread_buffer.py b/tests/test_thread_buffer.py
index ab5dba77be..2b7da2c0b0 100644
--- a/tests/test_thread_buffer.py
+++ b/tests/test_thread_buffer.py
@@ -24,6 +24,7 @@
class TestDataLoader(unittest.TestCase):
+
def setUp(self):
super().setUp()
diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py
index ca9fb244fc..9551dec703 100644
--- a/tests/test_threadcontainer.py
+++ b/tests/test_threadcontainer.py
@@ -36,6 +36,7 @@
class TestThreadContainer(unittest.TestCase):
+
@SkipIfNoModule("ignite")
def test_container(self):
net = torch.nn.Conv2d(1, 1, 3, padding=1)
diff --git a/tests/test_threshold_intensity.py b/tests/test_threshold_intensity.py
index 7fb28d413f..97c80eebcd 100644
--- a/tests/test_threshold_intensity.py
+++ b/tests/test_threshold_intensity.py
@@ -27,6 +27,7 @@
class TestThresholdIntensity(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, in_type, input_param, expected_value):
test_data = in_type(np.arange(10))
diff --git a/tests/test_threshold_intensityd.py b/tests/test_threshold_intensityd.py
index d5e7e5f517..867ebfe952 100644
--- a/tests/test_threshold_intensityd.py
+++ b/tests/test_threshold_intensityd.py
@@ -45,6 +45,7 @@
class TestThresholdIntensityd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, in_type, input_param, expected_value):
test_data = {"image": in_type(np.arange(10)), "label": in_type(np.arange(10)), "extra": in_type(np.arange(10))}
diff --git a/tests/test_timedcall_dist.py b/tests/test_timedcall_dist.py
index af7cf8720f..a814a99b25 100644
--- a/tests/test_timedcall_dist.py
+++ b/tests/test_timedcall_dist.py
@@ -50,6 +50,7 @@ def case_1_seconds_bad(arg=None):
class TestTimedCall(unittest.TestCase):
+
def test_good_call(self):
output = case_1_seconds()
self.assertEqual(output, "good")
diff --git a/tests/test_to_contiguous.py b/tests/test_to_contiguous.py
index 03733b9775..73a9ca27f6 100644
--- a/tests/test_to_contiguous.py
+++ b/tests/test_to_contiguous.py
@@ -21,6 +21,7 @@
class TestToContiguous(unittest.TestCase):
+
def test_contiguous_dict(self):
tochange = np.moveaxis(np.zeros((2, 3, 4)), 0, -1)
test_dict = {"test_key": [[1]], 0: np.array(0), 1: np.array([0]), "nested": {"nested": [tochange]}}
diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py
index 12a377181d..5a1754e7c5 100644
--- a/tests/test_to_cupy.py
+++ b/tests/test_to_cupy.py
@@ -26,6 +26,7 @@
@skipUnless(HAS_CUPY, "CuPy is required.")
class TestToCupy(unittest.TestCase):
+
def test_cupy_input(self):
test_data = cp.array([[1, 2], [3, 4]], dtype=cp.float32)
test_data = cp.rot90(test_data)
diff --git a/tests/test_to_cupyd.py b/tests/test_to_cupyd.py
index e9a3488489..a07ab671e1 100644
--- a/tests/test_to_cupyd.py
+++ b/tests/test_to_cupyd.py
@@ -26,6 +26,7 @@
@skipUnless(HAS_CUPY, "CuPy is required.")
class TestToCupyd(unittest.TestCase):
+
def test_cupy_input(self):
test_data = cp.array([[1, 2], [3, 4]])
test_data = cp.rot90(test_data)
diff --git a/tests/test_to_device.py b/tests/test_to_device.py
index cad2b65316..6a13ffca99 100644
--- a/tests/test_to_device.py
+++ b/tests/test_to_device.py
@@ -30,6 +30,7 @@
@skip_if_no_cuda
class TestToDevice(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_value(self, device):
converter = ToDevice(device=device, non_blocking=True)
diff --git a/tests/test_to_deviced.py b/tests/test_to_deviced.py
index 093c3b0c4d..19c2d0761f 100644
--- a/tests/test_to_deviced.py
+++ b/tests/test_to_deviced.py
@@ -22,6 +22,7 @@
@skip_if_no_cuda
class TestToDeviced(unittest.TestCase):
+
def test_value(self):
device = "cuda:0"
data = [{"img": torch.tensor(i)} for i in range(4)]
diff --git a/tests/test_to_from_meta_tensord.py b/tests/test_to_from_meta_tensord.py
index 470826313a..fe777cec77 100644
--- a/tests/test_to_from_meta_tensord.py
+++ b/tests/test_to_from_meta_tensord.py
@@ -42,6 +42,7 @@ def rand_string(min_len=5, max_len=10):
@unittest.skipIf(config.USE_META_DICT, "skipping not metatensor")
class TestToFromMetaTensord(unittest.TestCase):
+
@staticmethod
def get_im(shape=None, dtype=None, device=None):
if shape is None:
diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py
index 0c604fb9d4..8f7cf34865 100644
--- a/tests/test_to_numpy.py
+++ b/tests/test_to_numpy.py
@@ -25,6 +25,7 @@
class TestToNumpy(unittest.TestCase):
+
@skipUnless(HAS_CUPY, "CuPy is required.")
def test_cupy_input(self):
test_data = cp.array([[1, 2], [3, 4]])
diff --git a/tests/test_to_numpyd.py b/tests/test_to_numpyd.py
index d25bdf14a5..ae9b4c84b3 100644
--- a/tests/test_to_numpyd.py
+++ b/tests/test_to_numpyd.py
@@ -25,6 +25,7 @@
class TestToNumpyd(unittest.TestCase):
+
@skipUnless(HAS_CUPY, "CuPy is required.")
def test_cupy_input(self):
test_data = cp.array([[1, 2], [3, 4]])
diff --git a/tests/test_to_onehot.py b/tests/test_to_onehot.py
index 52307900af..48dba6fa68 100644
--- a/tests/test_to_onehot.py
+++ b/tests/test_to_onehot.py
@@ -44,6 +44,7 @@
class TestToOneHot(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_shape(self, input_data, expected_shape, expected_result=None):
result = one_hot(**input_data)
diff --git a/tests/test_to_pil.py b/tests/test_to_pil.py
index e4f74f6e1e..352e10bcc1 100644
--- a/tests/test_to_pil.py
+++ b/tests/test_to_pil.py
@@ -40,6 +40,7 @@
class TestToPIL(unittest.TestCase):
+
@parameterized.expand(TESTS)
@skipUnless(has_pil, "Requires `pillow` package.")
def test_value(self, test_data):
diff --git a/tests/test_to_pild.py b/tests/test_to_pild.py
index 4eb5999b15..1a0232e134 100644
--- a/tests/test_to_pild.py
+++ b/tests/test_to_pild.py
@@ -38,6 +38,7 @@
class TestToPIL(unittest.TestCase):
+
@parameterized.expand(TESTS)
@skipUnless(has_pil, "Requires `pillow` package.")
def test_values(self, input_param, test_data):
diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py
index cde845c246..50df80128b 100644
--- a/tests/test_to_tensor.py
+++ b/tests/test_to_tensor.py
@@ -33,6 +33,7 @@
class TestToTensor(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_array_input(self, test_data, expected_shape):
result = ToTensor(dtype=torch.float32, device="cpu", wrap_sequence=True)(test_data)
diff --git a/tests/test_to_tensord.py b/tests/test_to_tensord.py
index 82456786fd..1eab7b9485 100644
--- a/tests/test_to_tensord.py
+++ b/tests/test_to_tensord.py
@@ -34,6 +34,7 @@
class TestToTensord(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_array_input(self, test_data, expected_shape):
test_data = {"img": test_data}
diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py
index ec24f388f1..6f8f231829 100644
--- a/tests/test_torchscript_utils.py
+++ b/tests/test_torchscript_utils.py
@@ -23,11 +23,13 @@
class TestModule(torch.nn.Module):
+
def forward(self, x):
return x + 10
class TestTorchscript(unittest.TestCase):
+
def test_save_net_with_metadata(self):
"""Save a network without metadata to a file."""
m = torch.jit.script(TestModule())
diff --git a/tests/test_torchvision.py b/tests/test_torchvision.py
index 9cd536aa6f..2931b0c1a8 100644
--- a/tests/test_torchvision.py
+++ b/tests/test_torchvision.py
@@ -55,6 +55,7 @@
class TestTorchVision(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, input_param, input_data, expected_value):
set_determinism(seed=0)
diff --git a/tests/test_torchvision_fc_model.py b/tests/test_torchvision_fc_model.py
index e913b2b9b1..322cce1161 100644
--- a/tests/test_torchvision_fc_model.py
+++ b/tests/test_torchvision_fc_model.py
@@ -153,6 +153,7 @@
class TestTorchVisionFCModel(unittest.TestCase):
+
@parameterized.expand(
[TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]
+ ([TEST_CASE_8] if has_enum else [])
@@ -187,6 +188,7 @@ def test_with_pretrained(self, input_param, input_shape, expected_shape, expecte
class TestLookup(unittest.TestCase):
+
def test_get_module(self):
net = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32, 64), strides=(2, 2, 2, 2))
self.assertEqual(look_up_named_module("", net), net)
diff --git a/tests/test_torchvisiond.py b/tests/test_torchvisiond.py
index b2a6bcafc5..ec09692df9 100644
--- a/tests/test_torchvisiond.py
+++ b/tests/test_torchvisiond.py
@@ -52,6 +52,7 @@
class TestTorchVisiond(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_value(self, input_param, input_data, expected_value):
set_determinism(seed=0)
diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py
index 42906c84d2..dd139053e3 100644
--- a/tests/test_traceable_transform.py
+++ b/tests/test_traceable_transform.py
@@ -17,6 +17,7 @@
class _TraceTest(TraceableTransform):
+
def __call__(self, data):
self.push_transform(data)
return data
@@ -27,6 +28,7 @@ def pop(self, data):
class TestTraceable(unittest.TestCase):
+
def test_default(self):
expected_key = "_transforms"
a = _TraceTest()
diff --git a/tests/test_train_mode.py b/tests/test_train_mode.py
index 6136e2f7db..ae99f91363 100644
--- a/tests/test_train_mode.py
+++ b/tests/test_train_mode.py
@@ -19,6 +19,7 @@
class TestEvalMode(unittest.TestCase):
+
def test_eval_mode(self):
t = torch.rand(1, 1, 4, 4)
p = torch.nn.Conv2d(1, 1, 3)
diff --git a/tests/test_trainable_bilateral.py b/tests/test_trainable_bilateral.py
index 43b628be80..c69eff4071 100644
--- a/tests/test_trainable_bilateral.py
+++ b/tests/test_trainable_bilateral.py
@@ -273,6 +273,7 @@
@skip_if_no_cpp_extension
class BilateralFilterTestCaseCpuPrecise(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_cpu_precise(self, test_case_description, sigmas, input, expected):
# Params to determine the implementation to test
@@ -371,6 +372,7 @@ def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expec
@skip_if_no_cuda
@skip_if_no_cpp_extension
class BilateralFilterTestCaseCudaPrecise(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_cuda_precise(self, test_case_description, sigmas, input, expected):
# Skip this test
diff --git a/tests/test_trainable_joint_bilateral.py b/tests/test_trainable_joint_bilateral.py
index 8a9c69bda4..4263683ce2 100644
--- a/tests/test_trainable_joint_bilateral.py
+++ b/tests/test_trainable_joint_bilateral.py
@@ -358,6 +358,7 @@
@skip_if_no_cpp_extension
@skip_if_quick
class JointBilateralFilterTestCaseCpuPrecise(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_cpu_precise(self, test_case_description, sigmas, input, guide, expected):
# Params to determine the implementation to test
@@ -481,6 +482,7 @@ def test_cpu_precise_backwards(self, test_case_description, sigmas, input, guide
@skip_if_no_cuda
@skip_if_no_cpp_extension
class JointBilateralFilterTestCaseCudaPrecise(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_cuda_precise(self, test_case_description, sigmas, input, guide, expected):
# Skip this test
diff --git a/tests/test_transchex.py b/tests/test_transchex.py
index 9ad847cdaa..481c20e285 100644
--- a/tests/test_transchex.py
+++ b/tests/test_transchex.py
@@ -47,6 +47,7 @@
@skip_if_quick
class TestTranschex(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_TRANSCHEX)
def test_shape(self, input_param, expected_shape):
net = Transchex(**input_param)
diff --git a/tests/test_transform.py b/tests/test_transform.py
index ea738eaac3..9b05133391 100644
--- a/tests/test_transform.py
+++ b/tests/test_transform.py
@@ -20,6 +20,7 @@
class FaultyTransform(mt.Transform):
+
def __call__(self, _):
raise RuntimeError
@@ -29,6 +30,7 @@ def faulty_lambda(_):
class TestTransform(unittest.TestCase):
+
@classmethod
def setUpClass(cls):
super(__class__, cls).setUpClass()
diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py
index 914336668d..5a8dbba83c 100644
--- a/tests/test_transformerblock.py
+++ b/tests/test_transformerblock.py
@@ -39,6 +39,7 @@
class TestTransformerBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_TRANSFORMERBLOCK)
def test_shape(self, input_param, input_shape, expected_shape):
net = TransformerBlock(**input_param)
diff --git a/tests/test_transpose.py b/tests/test_transpose.py
index 0c9ae1c7e3..2f5ccd1235 100644
--- a/tests/test_transpose.py
+++ b/tests/test_transpose.py
@@ -27,6 +27,7 @@
class TestTranspose(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_transpose(self, im, indices):
tr = Transpose(indices)
diff --git a/tests/test_transposed.py b/tests/test_transposed.py
index ab80520fc9..e7c6ecbe8a 100644
--- a/tests/test_transposed.py
+++ b/tests/test_transposed.py
@@ -30,6 +30,7 @@
class TestTranspose(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_transpose(self, im, indices):
data = {"i": deepcopy(im), "j": deepcopy(im)}
diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py
index d1175f40c5..efe1f2cdf3 100644
--- a/tests/test_tversky_loss.py
+++ b/tests/test_tversky_loss.py
@@ -148,6 +148,7 @@
class TestTverskyLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = TverskyLoss(**input_param).forward(**input_data)
diff --git a/tests/test_ultrasound_confidence_map_transform.py b/tests/test_ultrasound_confidence_map_transform.py
index fbf0c4fe97..f672961700 100644
--- a/tests/test_ultrasound_confidence_map_transform.py
+++ b/tests/test_ultrasound_confidence_map_transform.py
@@ -518,6 +518,7 @@
class TestUltrasoundConfidenceMapTransform(unittest.TestCase):
+
def setUp(self):
self.input_img_np = np.expand_dims(TEST_INPUT, axis=0) # mock image (numpy array)
self.input_mask_np = np.expand_dims(TEST_MASK, axis=0) # mock mask (numpy array)
diff --git a/tests/test_unet.py b/tests/test_unet.py
index 9cb4af3379..1fb98f84b0 100644
--- a/tests/test_unet.py
+++ b/tests/test_unet.py
@@ -165,6 +165,7 @@
class TestUNET(unittest.TestCase):
+
@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shape):
net = UNet(**input_param).to(device)
diff --git a/tests/test_unetr.py b/tests/test_unetr.py
index 406d30aa12..46018d2bc0 100644
--- a/tests/test_unetr.py
+++ b/tests/test_unetr.py
@@ -57,6 +57,7 @@
@skip_if_quick
class TestUNETR(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_UNETR)
def test_shape(self, input_param, input_shape, expected_shape):
net = UNETR(**input_param)
diff --git a/tests/test_unetr_block.py b/tests/test_unetr_block.py
index 60004be25e..9701557ed6 100644
--- a/tests/test_unetr_block.py
+++ b/tests/test_unetr_block.py
@@ -102,6 +102,7 @@
class TestResBasicBlock(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_UNETR_BASIC_BLOCK)
def test_shape(self, input_param, input_shape, expected_shape):
for net in [UnetrBasicBlock(**input_param)]:
@@ -124,6 +125,7 @@ def test_script(self):
class TestUpBlock(unittest.TestCase):
+
@parameterized.expand(TEST_UP_BLOCK)
def test_shape(self, input_param, input_shape, expected_shape, skip_shape):
net = UnetrUpBlock(**input_param)
@@ -140,6 +142,7 @@ def test_script(self):
class TestPrUpBlock(unittest.TestCase):
+
@parameterized.expand(TEST_PRUP_BLOCK)
def test_shape(self, input_param, input_shape, expected_shape):
net = UnetrPrUpBlock(**input_param)
diff --git a/tests/test_unified_focal_loss.py b/tests/test_unified_focal_loss.py
index 0e7217e2b4..3b868a560e 100644
--- a/tests/test_unified_focal_loss.py
+++ b/tests/test_unified_focal_loss.py
@@ -38,6 +38,7 @@
class TestAsymmetricUnifiedFocalLoss(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_result(self, input_data, expected_val):
loss = AsymmetricUnifiedFocalLoss()
diff --git a/tests/test_upsample_block.py b/tests/test_upsample_block.py
index a82a31b064..e4890c83bc 100644
--- a/tests/test_upsample_block.py
+++ b/tests/test_upsample_block.py
@@ -121,6 +121,7 @@
class TestUpsample(unittest.TestCase):
+
@parameterized.expand(TEST_CASES + TEST_CASES_EQ + TEST_CASES_EQ2)
def test_shape(self, input_param, input_shape, expected_shape):
net = UpSample(**input_param)
diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py
index 619ae8aee3..6e655289e4 100644
--- a/tests/test_utils_pytorch_numpy_unification.py
+++ b/tests/test_utils_pytorch_numpy_unification.py
@@ -29,6 +29,7 @@
class TestPytorchNumpyUnification(unittest.TestCase):
+
def setUp(self) -> None:
set_determinism(0)
diff --git a/tests/test_varautoencoder.py b/tests/test_varautoencoder.py
index b050983d2c..e957dcfb61 100644
--- a/tests/test_varautoencoder.py
+++ b/tests/test_varautoencoder.py
@@ -108,6 +108,7 @@
class TestVarAutoEncoder(unittest.TestCase):
+
@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shape):
net = VarAutoEncoder(**input_param).to(device)
diff --git a/tests/test_varnet.py b/tests/test_varnet.py
index 3ec6b0f087..a46d58d6a2 100644
--- a/tests/test_varnet.py
+++ b/tests/test_varnet.py
@@ -32,6 +32,7 @@
class TestVarNet(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, coil_sens_model, refinement_model, num_cascades, input_shape, expected_shape):
net = VariationalNetworkModel(coil_sens_model, refinement_model, num_cascades).to(device)
diff --git a/tests/test_version.py b/tests/test_version.py
index 15f8cd36c6..35ce8d9a2f 100644
--- a/tests/test_version.py
+++ b/tests/test_version.py
@@ -75,6 +75,7 @@ def _pairwise(iterable):
class TestVersionCompare(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_compare_leq(self, a, b, expected=True):
"""Test version_leq with `a` and `b`"""
diff --git a/tests/test_video_datasets.py b/tests/test_video_datasets.py
index 790feb51ee..6e344e1caa 100644
--- a/tests/test_video_datasets.py
+++ b/tests/test_video_datasets.py
@@ -31,6 +31,7 @@
class Base:
+
class TestVideoDataset(unittest.TestCase):
video_source: int | str
ds: type[VideoDataset]
diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py
index bb3ff7237a..b641599af2 100644
--- a/tests/test_vis_cam.py
+++ b/tests/test_vis_cam.py
@@ -67,6 +67,7 @@
class TestClassActivationMap(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_shape(self, input_data, expected_shape):
if input_data["model"] == "densenet2d":
diff --git a/tests/test_vis_gradbased.py b/tests/test_vis_gradbased.py
index 0fbe328c83..e9db0af240 100644
--- a/tests/test_vis_gradbased.py
+++ b/tests/test_vis_gradbased.py
@@ -21,6 +21,7 @@
class DenseNetAdjoint(DenseNet121):
+
def __call__(self, x, adjoint_info):
if adjoint_info != 42:
raise ValueError
@@ -48,6 +49,7 @@ def __call__(self, x, adjoint_info):
class TestGradientClassActivationMap(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, vis_type, model, shape):
device = "cuda:0" if torch.cuda.is_available() else "cpu"
diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py
index 4b554de0aa..325b74b3ce 100644
--- a/tests/test_vis_gradcam.py
+++ b/tests/test_vis_gradcam.py
@@ -24,6 +24,7 @@
class DenseNetAdjoint(DenseNet121):
+
def __call__(self, x, adjoint_info):
if adjoint_info != 42:
raise ValueError
@@ -149,6 +150,7 @@ def __call__(self, x, adjoint_info):
@skip_if_quick
class TestGradientClassActivationMap(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_shape(self, cam_class, input_data, expected_shape):
if input_data["model"] == "densenet2d":
diff --git a/tests/test_vit.py b/tests/test_vit.py
index f911c2d5c9..a84883cba0 100644
--- a/tests/test_vit.py
+++ b/tests/test_vit.py
@@ -61,6 +61,7 @@
@skip_if_quick
class TestViT(unittest.TestCase):
+
@parameterized.expand(TEST_CASE_Vit)
def test_shape(self, input_param, input_shape, expected_shape):
net = ViT(**input_param)
diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py
index 5e95d3c7fb..cc3d493bb3 100644
--- a/tests/test_vitautoenc.py
+++ b/tests/test_vitautoenc.py
@@ -66,6 +66,7 @@
@skip_if_quick
class TestVitAutoenc(unittest.TestCase):
+
def setUp(self):
self.threads = torch.get_num_threads()
torch.set_num_threads(4)
diff --git a/tests/test_vnet.py b/tests/test_vnet.py
index 633893ce51..0ebf060434 100644
--- a/tests/test_vnet.py
+++ b/tests/test_vnet.py
@@ -55,6 +55,7 @@
class TestVNet(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_VNET_2D_1,
diff --git a/tests/test_vote_ensemble.py b/tests/test_vote_ensemble.py
index 32ff120c5d..4abdd0b050 100644
--- a/tests/test_vote_ensemble.py
+++ b/tests/test_vote_ensemble.py
@@ -71,6 +71,7 @@
class TestVoteEnsemble(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, input_param, img, expected_value):
result = VoteEnsemble(**input_param)(img)
diff --git a/tests/test_vote_ensembled.py b/tests/test_vote_ensembled.py
index 17f9d54835..957133d7fc 100644
--- a/tests/test_vote_ensembled.py
+++ b/tests/test_vote_ensembled.py
@@ -86,6 +86,7 @@
class TestVoteEnsembled(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_value(self, input_param, img, expected_value):
result = VoteEnsembled(**input_param)(img)
diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py
index 53ef2fc18f..ef420ef20c 100644
--- a/tests/test_voxelmorph.py
+++ b/tests/test_voxelmorph.py
@@ -245,6 +245,7 @@
class TestVOXELMORPH(unittest.TestCase):
+
@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shape):
net = VoxelMorphUNet(**input_param).to(device)
diff --git a/tests/test_warp.py b/tests/test_warp.py
index e614973f90..bac595224f 100644
--- a/tests/test_warp.py
+++ b/tests/test_warp.py
@@ -106,6 +106,7 @@
@skip_if_quick
class TestWarp(unittest.TestCase):
+
def setUp(self):
config = testing_data_config("images", "Prostate_T2W_AX_1")
download_url_or_skip_test(
diff --git a/tests/test_watershed.py b/tests/test_watershed.py
index a5a232ba3c..3f7a29bfe7 100644
--- a/tests/test_watershed.py
+++ b/tests/test_watershed.py
@@ -43,6 +43,7 @@
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
@unittest.skipUnless(has_scipy, "Requires scipy library.")
class TestWatershed(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_output(self, args, image, hover_map, expected_shape):
mask = GenerateWatershedMask()(image)
diff --git a/tests/test_watershedd.py b/tests/test_watershedd.py
index c12f5ad140..fc44996be4 100644
--- a/tests/test_watershedd.py
+++ b/tests/test_watershedd.py
@@ -48,6 +48,7 @@
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
@unittest.skipUnless(has_scipy, "Requires scipy library.")
class TestWatershedd(unittest.TestCase):
+
@parameterized.expand(TESTS)
def test_output(self, args, image, hover_map, expected_shape):
data = {"output": image, "hover_map": hover_map}
diff --git a/tests/test_weight_init.py b/tests/test_weight_init.py
index 376faacc56..a682ec6cc9 100644
--- a/tests/test_weight_init.py
+++ b/tests/test_weight_init.py
@@ -32,6 +32,7 @@
class TestWeightInit(unittest.TestCase):
+
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_shape):
im = torch.rand(input_shape)
diff --git a/tests/test_weighted_random_sampler_dist.py b/tests/test_weighted_random_sampler_dist.py
index d38bab54f0..8e37482da6 100644
--- a/tests/test_weighted_random_sampler_dist.py
+++ b/tests/test_weighted_random_sampler_dist.py
@@ -24,6 +24,7 @@
@skip_if_windows
@skip_if_darwin
class DistributedWeightedRandomSamplerTest(DistTestCase):
+
@DistCall(nnodes=1, nproc_per_node=2)
def test_sampling(self):
data = [1, 2, 3, 4, 5]
diff --git a/tests/test_with_allow_missing_keys.py b/tests/test_with_allow_missing_keys.py
index ec55654f07..427f64c705 100644
--- a/tests/test_with_allow_missing_keys.py
+++ b/tests/test_with_allow_missing_keys.py
@@ -19,6 +19,7 @@
class TestWithAllowMissingKeysMode(unittest.TestCase):
+
def setUp(self):
self.data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)}
diff --git a/tests/test_write_metrics_reports.py b/tests/test_write_metrics_reports.py
index 4f61e43fe1..1013f15d85 100644
--- a/tests/test_write_metrics_reports.py
+++ b/tests/test_write_metrics_reports.py
@@ -23,6 +23,7 @@
class TestWriteMetricsReports(unittest.TestCase):
+
def test_content(self):
with tempfile.TemporaryDirectory() as tempdir:
write_metrics_reports(
diff --git a/tests/test_wsi_sliding_window_splitter.py b/tests/test_wsi_sliding_window_splitter.py
index ac1a136489..c510ece272 100644
--- a/tests/test_wsi_sliding_window_splitter.py
+++ b/tests/test_wsi_sliding_window_splitter.py
@@ -102,6 +102,7 @@
# Filtering functions test cases
def gen_location_filter(locations):
+
def my_filter(patch, loc):
if loc in locations:
return False
@@ -198,6 +199,7 @@ def setUpModule():
class WSISlidingWindowSplitterTests(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_WSI_0_BASE,
diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py
index aae2b0dbaf..99a86c5ac8 100644
--- a/tests/test_wsireader.py
+++ b/tests/test_wsireader.py
@@ -402,6 +402,7 @@ def setUpModule():
class WSIReaderTests:
+
class Tests(unittest.TestCase):
backend = None
@@ -640,6 +641,7 @@ def test_errors(self, file_path, reader_kwargs, patch_info, exception):
@skipUnless(has_cucim, "Requires cucim")
class TestCuCIM(WSIReaderTests.Tests):
+
@classmethod
def setUpClass(cls):
cls.backend = "cucim"
@@ -647,6 +649,7 @@ def setUpClass(cls):
@skipUnless(has_osl, "Requires openslide")
class TestOpenSlide(WSIReaderTests.Tests):
+
@classmethod
def setUpClass(cls):
cls.backend = "openslide"
@@ -654,6 +657,7 @@ def setUpClass(cls):
@skipUnless(has_tiff, "Requires tifffile")
class TestTiffFile(WSIReaderTests.Tests):
+
@classmethod
def setUpClass(cls):
cls.backend = "tifffile"
diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py
index c4c7fad5da..de7fad48da 100644
--- a/tests/test_zarr_avg_merger.py
+++ b/tests/test_zarr_avg_merger.py
@@ -256,6 +256,7 @@
@unittest.skipUnless(has_zarr and has_numcodecs, "Requires zarr (and numcodecs) packages.)")
class ZarrAvgMergerTests(unittest.TestCase):
+
@parameterized.expand(
[
TEST_CASE_0_DEFAULT_DTYPE,
diff --git a/tests/test_zipdataset.py b/tests/test_zipdataset.py
index de8a8e80d6..2939ff3f49 100644
--- a/tests/test_zipdataset.py
+++ b/tests/test_zipdataset.py
@@ -20,6 +20,7 @@
class Dataset_(torch.utils.data.Dataset):
+
def __init__(self, length, index_only=True):
self.len = length
self.index_only = index_only
@@ -48,6 +49,7 @@ def __getitem__(self, index):
class TestZipDataset(unittest.TestCase):
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_value(self, datasets, transform, expected_output, expected_length):
test_dataset = ZipDataset(datasets=datasets, transform=transform)
diff --git a/tests/test_zoom.py b/tests/test_zoom.py
index e1ea3c25a3..2db2df4486 100644
--- a/tests/test_zoom.py
+++ b/tests/test_zoom.py
@@ -43,6 +43,7 @@
class TestZoom(NumpyImageTestCase2D):
+
@parameterized.expand(VALID_CASES)
def test_pending_ops(self, zoom, mode, align_corners=False, keep_size=False):
im = MetaTensor(self.imt[0], meta={"a": "b", "affine": DEFAULT_TEST_AFFINE})
diff --git a/tests/test_zoom_affine.py b/tests/test_zoom_affine.py
index dc39a4f1c2..ae8e688d96 100644
--- a/tests/test_zoom_affine.py
+++ b/tests/test_zoom_affine.py
@@ -64,6 +64,7 @@
class TestZoomAffine(unittest.TestCase):
+
@parameterized.expand(VALID_CASES)
def test_correct(self, affine, scale, expected):
output = zoom_affine(affine, scale, diagonal=False)
diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py
index 1dcbf98572..ad91f398ff 100644
--- a/tests/test_zoomd.py
+++ b/tests/test_zoomd.py
@@ -34,6 +34,7 @@
class TestZoomd(NumpyImageTestCase2D):
+
@parameterized.expand(VALID_CASES)
def test_correct_results(self, zoom, mode, keep_size, align_corners=None):
key = "img"
diff --git a/tests/utils.py b/tests/utils.py
index ee800598bb..ea73a3ed81 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -677,6 +677,7 @@ def setUp(self):
class TorchImageTestCase2D(NumpyImageTestCase2D):
+
def setUp(self):
NumpyImageTestCase2D.setUp(self)
self.imt = torch.tensor(self.imt)
@@ -707,6 +708,7 @@ def setUp(self):
class TorchImageTestCase3D(NumpyImageTestCase3D):
+
def setUp(self):
NumpyImageTestCase3D.setUp(self)
self.imt = torch.tensor(self.imt)
From 449c2fb0e5eeb1bf7423580cee8effe29198ccb5 Mon Sep 17 00:00:00 2001
From: Ibrahim Hadzic
Date: Mon, 5 Feb 2024 22:30:53 -0500
Subject: [PATCH 62/88] Instantiation mode `"partial"` to `"callable"`. Return
the `_target_` component as-is when in `_mode_="callable"` and no kwargs are
specified (#7413)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
### Description
A `_target_` component with `_mode_="partial"` will still be wrapped in
`functools.partial` even when no kwargs are passed:
`functool.partial(component)`. In such cases, the component can just be
returned as-is.
If you agree with this, I will add tests for it. Thank you!
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: Ibrahim Hadzic
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
docs/source/config_syntax.md | 7 ++++---
monai/bundle/config_item.py | 2 +-
monai/utils/enums.py | 2 +-
monai/utils/module.py | 11 +++++++----
tests/test_config_item.py | 2 +-
tests/test_config_parser.py | 6 ++----
6 files changed, 16 insertions(+), 14 deletions(-)
diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md
index c1e3d5cbe9..c932879b5a 100644
--- a/docs/source/config_syntax.md
+++ b/docs/source/config_syntax.md
@@ -168,9 +168,10 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k
- `_mode_` specifies the operating mode when the component is instantiated or the callable is called.
it currently supports the following values:
- `"default"` (default) -- return the return value of ``_target_(**kwargs)``
- - `"partial"` -- return a partial function of ``functools.partial(_target_, **kwargs)`` (this is often
- useful when some portion of the full set of arguments are supplied to the ``_target_``, and the user wants to
- call it with additional arguments later).
+ - `"callable"` -- return a callable, either as ``_target_`` itself or, if ``kwargs`` are provided, as a
+ partial function of ``functools.partial(_target_, **kwargs)``. Useful for defining a class or function
+ that will be instantied or called later. User can pre-define some arguments to the ``_target_`` and call
+ it with additional arguments later.
- `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``,
see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).
diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py
index c6da0a73de..844d5b30bf 100644
--- a/monai/bundle/config_item.py
+++ b/monai/bundle/config_item.py
@@ -181,7 +181,7 @@ class ConfigComponent(ConfigItem, Instantiable):
- ``"_mode_"`` (optional): operating mode for invoking the callable ``component`` defined by ``"_target_"``:
- ``"default"``: returns ``component(**kwargs)``
- - ``"partial"``: returns ``functools.partial(component, **kwargs)``
+ - ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``
- ``"debug"``: returns ``pdb.runcall(component, **kwargs)``
Other fields in the config content are input arguments to the python module.
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index a0847dd76c..b786e92151 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -411,7 +411,7 @@ class CompInitMode(StrEnum):
"""
DEFAULT = "default"
- PARTIAL = "partial"
+ CALLABLE = "callable"
DEBUG = "debug"
diff --git a/monai/utils/module.py b/monai/utils/module.py
index db62e1e72b..0dcf22fcd7 100644
--- a/monai/utils/module.py
+++ b/monai/utils/module.py
@@ -231,11 +231,14 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any:
Args:
__path: if a string is provided, it's interpreted as the full path of the target class or function component.
- If a callable is provided, ``__path(**kwargs)`` or ``functools.partial(__path, **kwargs)`` will be returned.
+ If a callable is provided, ``__path(**kwargs)`` will be invoked and returned for ``__mode="default"``.
+ For ``__mode="callable"``, the callable will be returned as ``__path`` or, if ``kwargs`` are provided,
+ as ``functools.partial(__path, **kwargs)`` for future invoking.
+
__mode: the operating mode for invoking the (callable) ``component`` represented by ``__path``:
- ``"default"``: returns ``component(**kwargs)``
- - ``"partial"``: returns ``functools.partial(component, **kwargs)``
+ - ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``
- ``"debug"``: returns ``pdb.runcall(component, **kwargs)``
kwargs: keyword arguments to the callable represented by ``__path``.
@@ -259,8 +262,8 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any:
return component
if m == CompInitMode.DEFAULT:
return component(**kwargs)
- if m == CompInitMode.PARTIAL:
- return partial(component, **kwargs)
+ if m == CompInitMode.CALLABLE:
+ return partial(component, **kwargs) if kwargs else component
if m == CompInitMode.DEBUG:
warnings.warn(
f"\n\npdb: instantiating component={component}, mode={m}\n"
diff --git a/tests/test_config_item.py b/tests/test_config_item.py
index 72f54adf0a..4909ecf6be 100644
--- a/tests/test_config_item.py
+++ b/tests/test_config_item.py
@@ -37,7 +37,7 @@
TEST_CASE_5 = [{"_target_": "LoadImaged", "_disabled_": "true", "keys": ["image"]}, dict]
# test non-monai modules and excludes
TEST_CASE_6 = [{"_target_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4}, torch.optim.Adam]
-TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "partial"}, partial]
+TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "callable"}, partial]
# test args contains "name" field
TEST_CASE_8 = [
{"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25},
diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py
index 41d7aa7a4e..8947158857 100644
--- a/tests/test_config_parser.py
+++ b/tests/test_config_parser.py
@@ -72,7 +72,6 @@ def case_pdb_inst(sarg=None):
class TestClass:
-
@staticmethod
def compute(a, b, func=lambda x, y: x + y):
return func(a, b)
@@ -127,7 +126,6 @@ def __call__(self, a, b):
class TestConfigParser(unittest.TestCase):
-
def test_config_content(self):
test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}}
parser = ConfigParser(config=test_config)
@@ -183,7 +181,7 @@ def test_function(self, config):
parser = ConfigParser(config=config, globals={"TestClass": TestClass})
for id in config:
if id in ("compute", "cls_compute"):
- parser[f"{id}#_mode_"] = "partial"
+ parser[f"{id}#_mode_"] = "callable"
func = parser.get_parsed_content(id=id)
self.assertTrue(id in parser.ref_resolver.resolved_content)
if id == "error_func":
@@ -279,7 +277,7 @@ def test_lambda_reference(self):
def test_non_str_target(self):
configs = {
- "fwd": {"_target_": "$@model.forward", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "partial"},
+ "fwd": {"_target_": "$@model.forward", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "callable"},
"model": {"_target_": "monai.networks.nets.resnet.resnet18", "pretrained": False, "spatial_dims": 2},
}
self.assertTrue(callable(ConfigParser(config=configs).fwd))
From eb8c8aadb8b72eb6a945c5a665d5dd8a0a2654d9 Mon Sep 17 00:00:00 2001
From: "Dr. Behrooz Hashemian" <3968947+drbeh@users.noreply.github.com>
Date: Tue, 6 Feb 2024 03:58:15 -0500
Subject: [PATCH 63/88] Add support for mlflow experiment name in auto3dseg
(#7442)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7441
This PR enable Auto3DSeg users to manage their runs and experiment more
efficiently in MLFlow under arbitrary experiment names, by providing
experiment name as an input parameter.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
change).
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
---------
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/apps/auto3dseg/auto_runner.py | 12 ++++++++-
monai/apps/auto3dseg/bundle_gen.py | 32 ++++++++++++++++++++++--
monai/apps/auto3dseg/ensemble_builder.py | 2 +-
3 files changed, 42 insertions(+), 4 deletions(-)
diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py
index e4c2d908b7..52a0824227 100644
--- a/monai/apps/auto3dseg/auto_runner.py
+++ b/monai/apps/auto3dseg/auto_runner.py
@@ -85,6 +85,7 @@ class AutoRunner:
can be skipped based on the analysis on the dataset from Auto3DSeg DataAnalyzer.
mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of the remote
tracking Server; MLflow runs will be recorded locally in algorithms' model folder if the value is None.
+ mlflow_experiment_name: the name of the experiment in MLflow server.
kwargs: image writing parameters for the ensemble inference. The kwargs format follows the SaveImage
transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage.
@@ -212,6 +213,7 @@ def __init__(
templates_path_or_url: str | None = None,
allow_skip: bool = True,
mlflow_tracking_uri: str | None = None,
+ mlflow_experiment_name: str | None = None,
**kwargs: Any,
):
if input is None and os.path.isfile(os.path.join(os.path.abspath(work_dir), "input.yaml")):
@@ -253,6 +255,7 @@ def __init__(
self.hpo = hpo and has_nni
self.hpo_backend = hpo_backend
self.mlflow_tracking_uri = mlflow_tracking_uri
+ self.mlflow_experiment_name = mlflow_experiment_name
self.kwargs = deepcopy(kwargs)
# parse input config for AutoRunner param overrides
@@ -268,7 +271,13 @@ def __init__(
if param in self.data_src_cfg and isinstance(self.data_src_cfg[param], bool):
setattr(self, param, self.data_src_cfg[param]) # e.g. self.analyze = self.data_src_cfg["analyze"]
- for param in ["algos", "hpo_backend", "templates_path_or_url", "mlflow_tracking_uri"]: # override from config
+ for param in [
+ "algos",
+ "hpo_backend",
+ "templates_path_or_url",
+ "mlflow_tracking_uri",
+ "mlflow_experiment_name",
+ ]: # override from config
if param in self.data_src_cfg:
setattr(self, param, self.data_src_cfg[param]) # e.g. self.algos = self.data_src_cfg["algos"]
@@ -813,6 +822,7 @@ def run(self):
data_stats_filename=self.datastats_filename,
data_src_cfg_name=self.data_src_cfg_name,
mlflow_tracking_uri=self.mlflow_tracking_uri,
+ mlflow_experiment_name=self.mlflow_experiment_name,
)
if self.gpu_customization:
diff --git a/monai/apps/auto3dseg/bundle_gen.py b/monai/apps/auto3dseg/bundle_gen.py
index 03b9c8bbf4..8a54d18be7 100644
--- a/monai/apps/auto3dseg/bundle_gen.py
+++ b/monai/apps/auto3dseg/bundle_gen.py
@@ -85,7 +85,8 @@ def __init__(self, template_path: PathLike):
self.template_path = template_path
self.data_stats_files = ""
self.data_list_file = ""
- self.mlflow_tracking_uri = None
+ self.mlflow_tracking_uri: str | None = None
+ self.mlflow_experiment_name: str | None = None
self.output_path = ""
self.name = ""
self.best_metric = None
@@ -139,7 +140,16 @@ def set_mlflow_tracking_uri(self, mlflow_tracking_uri: str | None) -> None:
the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if
the value is None.
"""
- self.mlflow_tracking_uri = mlflow_tracking_uri # type: ignore
+ self.mlflow_tracking_uri = mlflow_tracking_uri
+
+ def set_mlflow_experiment_name(self, mlflow_experiment_name: str | None) -> None:
+ """
+ Set the experiment name for MLflow server
+
+ Args:
+ mlflow_experiment_name: a string to specify the experiment name for MLflow server.
+ """
+ self.mlflow_experiment_name = mlflow_experiment_name
def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwargs: Any) -> dict:
"""
@@ -447,6 +457,7 @@ class BundleGen(AlgoGen):
mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of
the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if
the value is None.
+ mlfow_experiment_name: a string to specify the experiment name for MLflow server.
.. code-block:: bash
python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/datastats.yaml"
@@ -460,6 +471,7 @@ def __init__(
data_stats_filename: str | None = None,
data_src_cfg_name: str | None = None,
mlflow_tracking_uri: str | None = None,
+ mlflow_experiment_name: str | None = None,
):
if algos is None or isinstance(algos, (list, tuple, str)):
if templates_path_or_url is None:
@@ -513,6 +525,7 @@ def __init__(
self.data_stats_filename = data_stats_filename
self.data_src_cfg_name = data_src_cfg_name
self.mlflow_tracking_uri = mlflow_tracking_uri
+ self.mlflow_experiment_name = mlflow_experiment_name
self.history: list[dict] = []
def set_data_stats(self, data_stats_filename: str) -> None:
@@ -552,10 +565,23 @@ def set_mlflow_tracking_uri(self, mlflow_tracking_uri):
"""
self.mlflow_tracking_uri = mlflow_tracking_uri
+ def set_mlflow_experiment_name(self, mlflow_experiment_name):
+ """
+ Set the experiment name for MLflow server
+
+ Args:
+ mlflow_experiment_name: a string to specify the experiment name for MLflow server.
+ """
+ self.mlflow_experiment_name = mlflow_experiment_name
+
def get_mlflow_tracking_uri(self):
"""Get the tracking URI for MLflow server"""
return self.mlflow_tracking_uri
+ def get_mlflow_experiment_name(self):
+ """Get the experiment name for MLflow server"""
+ return self.mlflow_experiment_name
+
def get_history(self) -> list:
"""Get the history of the bundleAlgo object with their names/identifiers"""
return self.history
@@ -608,10 +634,12 @@ def generate(
data_stats = self.get_data_stats()
data_src_cfg = self.get_data_src()
mlflow_tracking_uri = self.get_mlflow_tracking_uri()
+ mlflow_experiment_name = self.get_mlflow_experiment_name()
gen_algo = deepcopy(algo)
gen_algo.set_data_stats(data_stats)
gen_algo.set_data_source(data_src_cfg)
gen_algo.set_mlflow_tracking_uri(mlflow_tracking_uri)
+ gen_algo.set_mlflow_experiment_name(mlflow_experiment_name)
name = f"{gen_algo.name}_{f_id}"
if allow_skip:
diff --git a/monai/apps/auto3dseg/ensemble_builder.py b/monai/apps/auto3dseg/ensemble_builder.py
index e29745e5cf..b2bea806de 100644
--- a/monai/apps/auto3dseg/ensemble_builder.py
+++ b/monai/apps/auto3dseg/ensemble_builder.py
@@ -464,7 +464,7 @@ def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFol
ensemble_method_name, supported=["AlgoEnsembleBestN", "AlgoEnsembleBestByFold"]
)
if self.ensemble_method_name == "AlgoEnsembleBestN":
- n_best = kwargs.pop("n_best", False) or 2
+ n_best = kwargs.pop("n_best", 2)
self.ensemble_method = AlgoEnsembleBestN(n_best=n_best)
elif self.ensemble_method_name == "AlgoEnsembleBestByFold":
self.ensemble_method = AlgoEnsembleBestByFold(n_fold=self.num_fold) # type: ignore
From 5ab247e8fa776fe82106edc22c659194b2c19f01 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Wed, 7 Feb 2024 21:47:23 +0800
Subject: [PATCH 64/88] Update gdown version (#7448)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
gdown library has been updated to fix
https://github.com/Project-MONAI/MONAI/issues/7383 since
https://github.com/wkentaro/gdown/pull/295
Update the gdown version for this PR
https://github.com/Project-MONAI/MONAI/pull/7384
### Description
A few sentences describing the changes proposed in this pull request.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/apps/utils.py | 2 +-
requirements-dev.txt | 2 +-
setup.cfg | 4 ++--
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/monai/apps/utils.py b/monai/apps/utils.py
index 442dbabba0..db541923b5 100644
--- a/monai/apps/utils.py
+++ b/monai/apps/utils.py
@@ -30,7 +30,7 @@
from monai.config.type_definitions import PathLike
from monai.utils import look_up_option, min_version, optional_import
-gdown, has_gdown = optional_import("gdown", "4.6.3")
+gdown, has_gdown = optional_import("gdown", "4.7.3")
if TYPE_CHECKING:
from tqdm import tqdm
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 706980576c..b08fef874b 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,7 +1,7 @@
# Full requirements for developments
-r requirements-min.txt
pytorch-ignite==0.4.11
-gdown>=4.4.0, <=4.6.3
+gdown>=4.7.3
scipy>=1.7.1
itk>=5.2
nibabel
diff --git a/setup.cfg b/setup.cfg
index 4180ced917..229e2ace56 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -52,7 +52,7 @@ all =
scipy>=1.7.1
pillow
tensorboard
- gdown==4.6.3
+ gdown>=4.7.3
pytorch-ignite==0.4.11
torchvision
itk>=5.2
@@ -97,7 +97,7 @@ pillow =
tensorboard =
tensorboard
gdown =
- gdown==4.6.3
+ gdown>=4.7.3
ignite =
pytorch-ignite==0.4.11
torchvision =
From 4b4c4f9c1707c6b8d9ca6fd5737ee3ad32082aba Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 20 Feb 2024 22:18:00 +0800
Subject: [PATCH 65/88] Skip "test_gaussian_filter" as a workaround for blossom
killed (#7474)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
workaround for #7445
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
tests/test_gaussian_filter.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/tests/test_gaussian_filter.py b/tests/test_gaussian_filter.py
index 4ab689c565..2167591c66 100644
--- a/tests/test_gaussian_filter.py
+++ b/tests/test_gaussian_filter.py
@@ -18,7 +18,7 @@
from parameterized import parameterized
from monai.networks.layers import GaussianFilter
-from tests.utils import skip_if_quick
+from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_quick
TEST_CASES = [[{"type": "erf", "gt": 2.0}], [{"type": "scalespace", "gt": 3.0}], [{"type": "sampled", "gt": 5.0}]]
TEST_CASES_GPU = [[{"type": "erf", "gt": 0.8, "device": "cuda"}], [{"type": "sampled", "gt": 5.0, "device": "cuda"}]]
@@ -34,6 +34,7 @@
]
+@SkipIfAtLeastPyTorchVersion((2, 2, 0)) # https://github.com/Project-MONAI/MONAI/issues/7445
class TestGaussianFilterBackprop(unittest.TestCase):
def code_to_run(self, input_args):
@@ -94,6 +95,7 @@ def test_train_slow(self, input_args):
self.code_to_run(input_args)
+@SkipIfAtLeastPyTorchVersion((2, 2, 0)) # https://github.com/Project-MONAI/MONAI/issues/7445
class GaussianFilterTestCase(unittest.TestCase):
def test_1d(self):
From d1de7647f242f28387cec15b565c783d5602f92f Mon Sep 17 00:00:00 2001
From: monai-bot <64792179+monai-bot@users.noreply.github.com>
Date: Tue, 20 Feb 2024 14:58:30 +0000
Subject: [PATCH 66/88] auto updates (#7463)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: monai-bot
Signed-off-by: monai-bot
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
tests/test_config_parser.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py
index 8947158857..cc890a0522 100644
--- a/tests/test_config_parser.py
+++ b/tests/test_config_parser.py
@@ -72,6 +72,7 @@ def case_pdb_inst(sarg=None):
class TestClass:
+
@staticmethod
def compute(a, b, func=lambda x, y: x + y):
return func(a, b)
@@ -126,6 +127,7 @@ def __call__(self, a, b):
class TestConfigParser(unittest.TestCase):
+
def test_config_content(self):
test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}}
parser = ConfigParser(config=test_config)
From 42a4e2c64ba9ea2ef05229448644ecbda7fd78a7 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Wed, 21 Feb 2024 20:48:37 +0800
Subject: [PATCH 67/88] Skip "test_resize" as a workaround for blossom killed
(#7484)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
workaround for #7445
### Description
`Resize` also uses `GaussianFilter` inside.
https://github.com/Project-MONAI/MONAI/blob/50f9aea67fb1ea7967020ad613ce83409261f2de/monai/transforms/spatial/functional.py#L332
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
tests/test_resize.py | 11 +++++++++--
tests/test_resized.py | 10 ++++++++--
2 files changed, 17 insertions(+), 4 deletions(-)
diff --git a/tests/test_resize.py b/tests/test_resize.py
index 33abfe4e1f..65b33ea649 100644
--- a/tests/test_resize.py
+++ b/tests/test_resize.py
@@ -21,7 +21,14 @@
from monai.data import MetaTensor, set_track_meta
from monai.transforms import Resize
from tests.lazy_transforms_utils import test_resampler_lazy
-from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, is_tf32_env, pytorch_after
+from tests.utils import (
+ TEST_NDARRAYS_ALL,
+ NumpyImageTestCase2D,
+ SkipIfAtLeastPyTorchVersion,
+ assert_allclose,
+ is_tf32_env,
+ pytorch_after,
+)
TEST_CASE_0 = [{"spatial_size": 15}, (6, 10, 15)]
@@ -39,7 +46,6 @@
class TestResize(NumpyImageTestCase2D):
-
def test_invalid_inputs(self):
with self.assertRaises(ValueError):
resize = Resize(spatial_size=(128, 128, 3), mode="order")
@@ -112,6 +118,7 @@ def test_correct_results(self, spatial_size, mode, anti_aliasing):
)
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_2_1, TEST_CASE_3, TEST_CASE_4])
+ @SkipIfAtLeastPyTorchVersion((2, 2, 0)) # https://github.com/Project-MONAI/MONAI/issues/7445
def test_longest_shape(self, input_param, expected_shape):
input_data = np.random.randint(0, 2, size=[3, 4, 7, 10])
input_param["size_mode"] = "longest"
diff --git a/tests/test_resized.py b/tests/test_resized.py
index ab4c9815ea..d62f29ab5c 100644
--- a/tests/test_resized.py
+++ b/tests/test_resized.py
@@ -21,7 +21,13 @@
from monai.data import MetaTensor, set_track_meta
from monai.transforms import Invertd, Resize, Resized
from tests.lazy_transforms_utils import test_resampler_lazy
-from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion
+from tests.utils import (
+ TEST_NDARRAYS_ALL,
+ NumpyImageTestCase2D,
+ SkipIfAtLeastPyTorchVersion,
+ assert_allclose,
+ test_local_inversion,
+)
TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 10, 15)]
@@ -58,8 +64,8 @@
]
+@SkipIfAtLeastPyTorchVersion((2, 2, 0)) # https://github.com/Project-MONAI/MONAI/issues/7445
class TestResized(NumpyImageTestCase2D):
-
def test_invalid_inputs(self):
with self.assertRaises(ValueError):
resize = Resized(keys="img", spatial_size=(128, 128, 3), mode="order")
From 1394916fcec7c9fe7efa10549de3a3e76dc82db3 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Wed, 21 Feb 2024 21:30:25 +0800
Subject: [PATCH 68/88] Fix Python 3.12 import AttributeError (#7482)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7458
### Description
https://github.com/python/cpython/blob/a21c0c7def9a8495f1166d9b434dfc301cb92bff/Lib/importlib/abc.py#L68
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/utils/module.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/monai/utils/module.py b/monai/utils/module.py
index 0dcf22fcd7..5e058c105b 100644
--- a/monai/utils/module.py
+++ b/monai/utils/module.py
@@ -209,7 +209,7 @@ def load_submodules(
if (is_pkg or load_all) and name not in sys.modules and match(exclude_pattern, name) is None:
try:
mod = import_module(name)
- importer.find_module(name).load_module(name) # type: ignore
+ importer.find_spec(name).loader.load_module(name) # type: ignore
submodules.append(mod)
except OptionalImportError:
pass # could not import the optional deps., they are ignored
From ff198228efd28c63b135c496dfc50e233b7bcbb6 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Thu, 22 Feb 2024 10:29:13 +0800
Subject: [PATCH 69/88] Update test_nnunetv2runner (#7483)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7013 #7478
### Description
replace `predict_from_raw_data` with `nnUNetPredictor` in
test_nnunetv2runner
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/apps/nnunet/nnunetv2_runner.py | 22 ++++++++++++----------
1 file changed, 12 insertions(+), 10 deletions(-)
diff --git a/monai/apps/nnunet/nnunetv2_runner.py b/monai/apps/nnunet/nnunetv2_runner.py
index e62809403e..44b3c24256 100644
--- a/monai/apps/nnunet/nnunetv2_runner.py
+++ b/monai/apps/nnunet/nnunetv2_runner.py
@@ -37,6 +37,7 @@ class nnUNetV2Runner: # noqa: N801
"""
``nnUNetV2Runner`` provides an interface in MONAI to use `nnU-Net` V2 library to analyze, train, and evaluate
neural networks for medical image segmentation tasks.
+ A version of nnunetv2 higher than 2.2 is needed for this class.
``nnUNetV2Runner`` can be used in two ways:
@@ -770,7 +771,7 @@ def find_best_configuration(
def predict(
self,
list_of_lists_or_source_folder: str | list[list[str]],
- output_folder: str,
+ output_folder: str | None | list[str],
model_training_output_dir: str,
use_folds: tuple[int, ...] | str | None = None,
tile_step_size: float = 0.5,
@@ -824,7 +825,7 @@ def predict(
"""
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
- from nnunetv2.inference.predict_from_raw_data import predict_from_raw_data
+ from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
n_processes_preprocessing = (
self.default_num_processes if num_processes_preprocessing < 0 else num_processes_preprocessing
@@ -832,20 +833,21 @@ def predict(
n_processes_segmentation_export = (
self.default_num_processes if num_processes_segmentation_export < 0 else num_processes_segmentation_export
)
-
- predict_from_raw_data(
- list_of_lists_or_source_folder=list_of_lists_or_source_folder,
- output_folder=output_folder,
- model_training_output_dir=model_training_output_dir,
- use_folds=use_folds,
+ predictor = nnUNetPredictor(
tile_step_size=tile_step_size,
use_gaussian=use_gaussian,
use_mirroring=use_mirroring,
- perform_everything_on_gpu=perform_everything_on_gpu,
+ perform_everything_on_device=perform_everything_on_gpu,
verbose=verbose,
+ )
+ predictor.initialize_from_trained_model_folder(
+ model_training_output_dir=model_training_output_dir, use_folds=use_folds, checkpoint_name=checkpoint_name
+ )
+ predictor.predict_from_files(
+ list_of_lists_or_source_folder=list_of_lists_or_source_folder,
+ output_folder_or_list_of_truncated_output_files=output_folder,
save_probabilities=save_probabilities,
overwrite=overwrite,
- checkpoint_name=checkpoint_name,
num_processes_preprocessing=n_processes_preprocessing,
num_processes_segmentation_export=n_processes_segmentation_export,
folder_with_segs_from_prev_stage=folder_with_segs_from_prev_stage,
From 5bbaab9d1b21672e6291931eae2333395b9dec17 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Fri, 23 Feb 2024 11:07:43 +0800
Subject: [PATCH 70/88] Fix github resource issue when build latest docker
(#7450)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7449
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/docker.yml | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index 065125cc33..65716f86f9 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -17,7 +17,8 @@ jobs:
versioning_dev:
# compute versioning file from python setup.py
# upload as artifact
- if: github.repository == 'Project-MONAI/MONAI'
+ # if: github.repository == 'Project-MONAI/MONAI'
+ if: ${{ false }} # disable docker build job project-monai/monai#7450
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
@@ -47,8 +48,8 @@ jobs:
rm -rf {*,.[^.]*}
docker_build_dev:
- # builds projectmonai/monai:latest
- if: github.repository == 'Project-MONAI/MONAI'
+ # if: github.repository == 'Project-MONAI/MONAI'
+ if: ${{ false }} # disable docker build job project-monai/monai#7450
needs: versioning_dev
runs-on: ubuntu-latest
steps:
From 473593e417c0a434b984c742596fdcb6560b44a2 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Sat, 24 Feb 2024 00:49:00 +0800
Subject: [PATCH 71/88] Use int16 instead of int8 in `LabelStats` (#7489)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Use uint8 instead of int8 in `LabelStats`.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/auto3dseg/analyzer.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py
index 56419da4cb..37f3faea21 100644
--- a/monai/auto3dseg/analyzer.py
+++ b/monai/auto3dseg/analyzer.py
@@ -460,7 +460,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
torch.set_grad_enabled(False)
ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore
- ndas_label: MetaTensor = d[self.label_key].astype(torch.int8) # (H,W,D)
+ ndas_label: MetaTensor = d[self.label_key].astype(torch.int16) # (H,W,D)
if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
@@ -472,7 +472,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
unique_label = unique_label.data.cpu().numpy()
- unique_label = unique_label.astype(np.int8).tolist()
+ unique_label = unique_label.astype(np.int16).tolist()
label_substats = [] # each element is one label
pixel_sum = 0
From 01a8a2459400b79df681d2f0701d3368cb341804 Mon Sep 17 00:00:00 2001
From: monai-bot <64792179+monai-bot@users.noreply.github.com>
Date: Mon, 26 Feb 2024 07:01:44 +0000
Subject: [PATCH 72/88] auto updates (#7495)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: monai-bot
Signed-off-by: monai-bot
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
tests/test_resize.py | 1 +
tests/test_resized.py | 1 +
2 files changed, 2 insertions(+)
diff --git a/tests/test_resize.py b/tests/test_resize.py
index 65b33ea649..d4c57e2742 100644
--- a/tests/test_resize.py
+++ b/tests/test_resize.py
@@ -46,6 +46,7 @@
class TestResize(NumpyImageTestCase2D):
+
def test_invalid_inputs(self):
with self.assertRaises(ValueError):
resize = Resize(spatial_size=(128, 128, 3), mode="order")
diff --git a/tests/test_resized.py b/tests/test_resized.py
index d62f29ab5c..243a4e6622 100644
--- a/tests/test_resized.py
+++ b/tests/test_resized.py
@@ -66,6 +66,7 @@
@SkipIfAtLeastPyTorchVersion((2, 2, 0)) # https://github.com/Project-MONAI/MONAI/issues/7445
class TestResized(NumpyImageTestCase2D):
+
def test_invalid_inputs(self):
with self.assertRaises(ValueError):
resize = Resized(keys="img", spatial_size=(128, 128, 3), mode="order")
From b0c96d881bc64b62a05d58ca827c6447c8c058ed Mon Sep 17 00:00:00 2001
From: "Timothy J. Baker" <62781117+tim-the-baker@users.noreply.github.com>
Date: Mon, 26 Feb 2024 11:32:23 -0500
Subject: [PATCH 73/88] Add sample_std parameter to RandGaussianNoise. (#7492)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes issue #7425
### Description
Add a `sample_std` parameter to `RandGaussianNoise` and
`RandGaussianNoised`. When True, the Gaussian's standard deviation is
sampled uniformly from 0 to std (i.e., what is currently done). When
False, the noise's standard deviation is non-random and set to std. The
default for sample_std would be True for backwards compatibility.
Changes were based on RandRicianNoise which already has a `sample_std`
parameter and is similar to RandGaussianNoise in concept and
implementation.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: Timothy Baker
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/intensity/array.py | 15 ++++++++++++---
monai/transforms/intensity/dictionary.py | 6 ++++--
tests/test_rand_gaussian_noise.py | 12 +++++++-----
tests/test_rand_gaussian_noised.py | 14 +++++++++-----
4 files changed, 32 insertions(+), 15 deletions(-)
diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py
index f9667402c9..a2f63a7482 100644
--- a/monai/transforms/intensity/array.py
+++ b/monai/transforms/intensity/array.py
@@ -91,24 +91,33 @@ class RandGaussianNoise(RandomizableTransform):
mean: Mean or “centre” of the distribution.
std: Standard deviation (spread) of distribution.
dtype: output data type, if None, same as input image. defaults to float32.
+ sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std.
"""
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
- def __init__(self, prob: float = 0.1, mean: float = 0.0, std: float = 0.1, dtype: DtypeLike = np.float32) -> None:
+ def __init__(
+ self,
+ prob: float = 0.1,
+ mean: float = 0.0,
+ std: float = 0.1,
+ dtype: DtypeLike = np.float32,
+ sample_std: bool = True,
+ ) -> None:
RandomizableTransform.__init__(self, prob)
self.mean = mean
self.std = std
self.dtype = dtype
self.noise: np.ndarray | None = None
+ self.sample_std = sample_std
def randomize(self, img: NdarrayOrTensor, mean: float | None = None) -> None:
super().randomize(None)
if not self._do_transform:
return None
- rand_std = self.R.uniform(0, self.std)
- noise = self.R.normal(self.mean if mean is None else mean, rand_std, size=img.shape)
+ std = self.R.uniform(0, self.std) if self.sample_std else self.std
+ noise = self.R.normal(self.mean if mean is None else mean, std, size=img.shape)
# noise is float64 array, convert to the output dtype to save memory
self.noise, *_ = convert_data_type(noise, dtype=self.dtype)
diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py
index 058ef87b95..7e93464e64 100644
--- a/monai/transforms/intensity/dictionary.py
+++ b/monai/transforms/intensity/dictionary.py
@@ -172,7 +172,7 @@
class RandGaussianNoised(RandomizableTransform, MapTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandGaussianNoise`.
- Add Gaussian noise to image. This transform assumes all the expected fields have same shape, if want to add
+ Add Gaussian noise to image. This transform assumes all the expected fields have same shape, if you want to add
different noise for every field, please use this transform separately.
Args:
@@ -183,6 +183,7 @@ class RandGaussianNoised(RandomizableTransform, MapTransform):
std: Standard deviation (spread) of distribution.
dtype: output data type, if None, same as input image. defaults to float32.
allow_missing_keys: don't raise exception if key is missing.
+ sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std.
"""
backend = RandGaussianNoise.backend
@@ -195,10 +196,11 @@ def __init__(
std: float = 0.1,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
+ sample_std: bool = True,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, prob)
- self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype)
+ self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype, sample_std=sample_std)
def set_random_state(
self, seed: int | None = None, state: np.random.RandomState | None = None
diff --git a/tests/test_rand_gaussian_noise.py b/tests/test_rand_gaussian_noise.py
index a56e54fe31..233b4dd1b6 100644
--- a/tests/test_rand_gaussian_noise.py
+++ b/tests/test_rand_gaussian_noise.py
@@ -22,22 +22,24 @@
TESTS = []
for p in TEST_NDARRAYS:
- TESTS.append(("test_zero_mean", p, 0, 0.1))
- TESTS.append(("test_non_zero_mean", p, 1, 0.5))
+ TESTS.append(("test_zero_mean", p, 0, 0.1, True))
+ TESTS.append(("test_non_zero_mean", p, 1, 0.5, True))
+ TESTS.append(("test_no_sample_std", p, 1, 0.5, False))
class TestRandGaussianNoise(NumpyImageTestCase2D):
@parameterized.expand(TESTS)
- def test_correct_results(self, _, im_type, mean, std):
+ def test_correct_results(self, _, im_type, mean, std, sample_std):
seed = 0
- gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std)
+ gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std, sample_std=sample_std)
gaussian_fn.set_random_state(seed)
im = im_type(self.imt)
noised = gaussian_fn(im)
np.random.seed(seed)
np.random.random()
- expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape)
+ _std = np.random.uniform(0, std) if sample_std else std
+ expected = self.imt + np.random.normal(mean, _std, size=self.imt.shape)
if isinstance(noised, torch.Tensor):
noised = noised.cpu()
np.testing.assert_allclose(expected, noised, atol=1e-5)
diff --git a/tests/test_rand_gaussian_noised.py b/tests/test_rand_gaussian_noised.py
index bcbed98b5a..e3df196be2 100644
--- a/tests/test_rand_gaussian_noised.py
+++ b/tests/test_rand_gaussian_noised.py
@@ -22,8 +22,9 @@
TESTS = []
for p in TEST_NDARRAYS:
- TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1])
- TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5])
+ TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1, True])
+ TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5, True])
+ TESTS.append(["test_no_sample_std", p, ["img1", "img2"], 1, 0.5, False])
seed = 0
@@ -31,15 +32,18 @@
class TestRandGaussianNoised(NumpyImageTestCase2D):
@parameterized.expand(TESTS)
- def test_correct_results(self, _, im_type, keys, mean, std):
- gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64)
+ def test_correct_results(self, _, im_type, keys, mean, std, sample_std):
+ gaussian_fn = RandGaussianNoised(
+ keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64, sample_std=sample_std
+ )
gaussian_fn.set_random_state(seed)
im = im_type(self.imt)
noised = gaussian_fn({k: im for k in keys})
np.random.seed(seed)
# simulate the randomize() of transform
np.random.random()
- noise = np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape)
+ _std = np.random.uniform(0, std) if sample_std else std
+ noise = np.random.normal(mean, _std, size=self.imt.shape)
for k in keys:
expected = self.imt + noise
if isinstance(noised[k], torch.Tensor):
From 771af492286f89b28ba7f9934d16b80fab325af4 Mon Sep 17 00:00:00 2001
From: Mathijs de Boer <8137653+MathijsdeBoer@users.noreply.github.com>
Date: Wed, 28 Feb 2024 06:10:40 +0100
Subject: [PATCH 74/88] Add __repr__ and __str__ to Metrics baseclass (#7487)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
### Description
When training a model using MONAI metrics for experiment tracking, I
tend to log which metrics I am using. Unfortunately, just sending the
metrics objects to Tensorboard will result in a list like
[CustomMetric1, CustomMetric2, , etc.]
Adding `__repr__` and `__str__` methods to the base class will solve
this small annoyance. The current implementation will only return the
class name, but if a certain metric would wish to report more data for
its `__repr__` string, this can be easily overridden in any subclass.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Mathijs de Boer
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Mathijs de Boer
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/metrics/metric.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py
index a6dc1a49a2..249b2dc951 100644
--- a/monai/metrics/metric.py
+++ b/monai/metrics/metric.py
@@ -37,6 +37,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
+ def __str__(self):
+ return self.__class__.__name__
+
class IterationMetric(Metric):
"""
From ee8bd4fe6a7dbcdf12d8d9aeb3f265a0d7598268 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Fri, 1 Mar 2024 16:23:09 +0800
Subject: [PATCH 75/88] Bump al-cheb/configure-pagefile-action from 1.3 to 1.4
(#7510)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Bumps
[al-cheb/configure-pagefile-action](https://github.com/al-cheb/configure-pagefile-action)
from 1.3 to 1.4.
Release notes
Sourced from al-cheb/configure-pagefile-action's
releases.
v1.4: Update task node version to 20
configure-pagefile-action
This action is intended to configure Pagefile size and location for
Windows images in GitHub Actions.
Available parameters
| Argument |
Description |
Format |
Default value |
minimum-size |
Set minimum size of Pagefile |
2048MB, 4GB, 8GB and etc |
8GB |
maximum-size |
Set maximum size of Pagefile |
The same like minimum-size |
minimum-size |
disk-root |
Set disk root where Pagefile will be located |
C: or D: |
D: |
Usage
name: CI
on: [push]
jobs:
build:
runs-on: windows-latest
steps:
- name: configure Pagefile
uses: al-cheb/configure-pagefile-action@v1.4
with:
minimum-size: 8
- name: configure Pagefile
uses: al-cheb/configure-pagefile-action@v1.4
with:
minimum-size: 8
maximum-size: 16
disk-root: "D:"
License
The scripts and documentation in this project are released under the
MIT
License
Commits
a3b6ebd
Merge pull request #20
from mikehardy/mikehardy-patch-1
850626f
build(deps): bump javascript dependencies / forward-port as needed
e7aac1b
fix: use node 20
d940d24
build(deps): use v4 of setup-node action, use node 20
dfdc038
build(deps): bump actions/checkout from 3 to 4
- See full diff in compare
view
[](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
.github/workflows/conda.yml | 2 +-
.github/workflows/pythonapp.yml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml
index a387c77ebd..367a24cbde 100644
--- a/.github/workflows/conda.yml
+++ b/.github/workflows/conda.yml
@@ -26,7 +26,7 @@ jobs:
steps:
- if: runner.os == 'windows'
name: Config pagefile (Windows only)
- uses: al-cheb/configure-pagefile-action@v1.3
+ uses: al-cheb/configure-pagefile-action@v1.4
with:
minimum-size: 8GB
maximum-size: 16GB
diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
index b011e65cf1..b7f2cfb9db 100644
--- a/.github/workflows/pythonapp.yml
+++ b/.github/workflows/pythonapp.yml
@@ -62,7 +62,7 @@ jobs:
steps:
- if: runner.os == 'windows'
name: Config pagefile (Windows only)
- uses: al-cheb/configure-pagefile-action@v1.3
+ uses: al-cheb/configure-pagefile-action@v1.4
with:
minimum-size: 8GB
maximum-size: 16GB
From 55be1d0226fae80dbefb1d9587f01d6d45ab909b Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Sun, 3 Mar 2024 23:09:21 +0800
Subject: [PATCH 76/88] Add arm support (#7500)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes # .
### Description
Add arm support
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
Dockerfile | 4 ++++
requirements-dev.txt | 4 ++--
setup.cfg | 4 ++--
tests/test_convert_to_onnx.py | 19 +++++++++++++------
tests/test_dynunet.py | 9 ++++++++-
tests/test_rand_affine.py | 2 +-
tests/test_rand_affined.py | 4 +++-
tests/test_spatial_resampled.py | 9 ++++++++-
8 files changed, 41 insertions(+), 14 deletions(-)
diff --git a/Dockerfile b/Dockerfile
index cb1300ea90..7383837585 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -16,6 +16,10 @@ FROM ${PYTORCH_IMAGE}
LABEL maintainer="monai.contact@gmail.com"
+# TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431)
+WORKDIR /opt
+RUN git clone --recursive https://github.com/zarr-developers/numcodecs.git && pip wheel numcodecs
+
WORKDIR /opt/monai
# install full deps
diff --git a/requirements-dev.txt b/requirements-dev.txt
index b08fef874b..af1b8b89d5 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -26,7 +26,7 @@ mypy>=1.5.0
ninja
torchvision
psutil
-cucim>=23.2.0; platform_system == "Linux"
+cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10"
openslide-python
imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
tifffile; platform_system == "Linux" or platform_system == "Darwin"
@@ -46,7 +46,7 @@ pynrrd
pre-commit
pydicom
h5py
-nni; platform_system == "Linux"
+nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
optuna
git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
onnx>=1.13.0
diff --git a/setup.cfg b/setup.cfg
index 229e2ace56..d7cb703d25 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -59,7 +59,7 @@ all =
tqdm>=4.47.0
lmdb
psutil
- cucim>=23.2.0
+ cucim-cu12; python_version >= '3.9' and python_version <= '3.10'
openslide-python
tifffile
imagecodecs
@@ -111,7 +111,7 @@ lmdb =
psutil =
psutil
cucim =
- cucim>=23.2.0
+ cucim-cu12
openslide =
openslide-python
tifffile =
diff --git a/tests/test_convert_to_onnx.py b/tests/test_convert_to_onnx.py
index 398d260c52..798c510800 100644
--- a/tests/test_convert_to_onnx.py
+++ b/tests/test_convert_to_onnx.py
@@ -12,6 +12,7 @@
from __future__ import annotations
import itertools
+import platform
import unittest
import torch
@@ -29,6 +30,12 @@
TESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False], [True, False]))
TESTS_ORT = list(itertools.product(TORCH_DEVICE_OPTIONS, [True]))
+ON_AARCH64 = platform.machine() == "aarch64"
+if ON_AARCH64:
+ rtol, atol = 1e-1, 1e-2
+else:
+ rtol, atol = 1e-3, 1e-4
+
onnx, _ = optional_import("onnx")
@@ -56,8 +63,8 @@ def test_unet(self, device, use_trace, use_ort):
device=device,
use_ort=use_ort,
use_trace=use_trace,
- rtol=1e-3,
- atol=1e-4,
+ rtol=rtol,
+ atol=atol,
)
else:
# https://github.com/pytorch/pytorch/blob/release/1.9/torch/onnx/__init__.py#L182
@@ -72,8 +79,8 @@ def test_unet(self, device, use_trace, use_ort):
device=device,
use_ort=use_ort,
use_trace=use_trace,
- rtol=1e-3,
- atol=1e-4,
+ rtol=rtol,
+ atol=atol,
)
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))
@@ -107,8 +114,8 @@ def test_seg_res_net(self, device, use_ort):
device=device,
use_ort=use_ort,
use_trace=True,
- rtol=1e-3,
- atol=1e-4,
+ rtol=rtol,
+ atol=atol,
)
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))
diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py
index b0137ae245..f3c982056c 100644
--- a/tests/test_dynunet.py
+++ b/tests/test_dynunet.py
@@ -11,6 +11,7 @@
from __future__ import annotations
+import platform
import unittest
from typing import Any, Sequence
@@ -24,6 +25,12 @@
InstanceNorm3dNVFuser, _ = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
+ON_AARCH64 = platform.machine() == "aarch64"
+if ON_AARCH64:
+ rtol, atol = 1e-2, 1e-2
+else:
+ rtol, atol = 1e-4, 1e-4
+
device = "cuda" if torch.cuda.is_available() else "cpu"
strides: Sequence[Sequence[int] | int]
@@ -159,7 +166,7 @@ def test_consistency(self, input_param, input_shape, _):
with eval_mode(net_fuser):
result_fuser = net_fuser(input_tensor)
- assert_allclose(result, result_fuser, rtol=1e-4, atol=1e-4)
+ assert_allclose(result, result_fuser, rtol=rtol, atol=atol)
class TestDynUNetDeepSupervision(unittest.TestCase):
diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py
index f37f7827bb..23e3fd148c 100644
--- a/tests/test_rand_affine.py
+++ b/tests/test_rand_affine.py
@@ -147,7 +147,7 @@ def test_rand_affine(self, input_param, input_data, expected_val):
g.set_random_state(123)
result = g(**input_data)
g.rand_affine_grid.affine = torch.eye(4, dtype=torch.float64) # reset affine
- test_resampler_lazy(g, result, input_param, input_data, seed=123)
+ test_resampler_lazy(g, result, input_param, input_data, seed=123, rtol=_rtol)
if input_param.get("cache_grid", False):
self.assertTrue(g._cached_grid is not None)
assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4, type_test="tensor")
diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py
index 20c50954e2..32fde8dc0f 100644
--- a/tests/test_rand_affined.py
+++ b/tests/test_rand_affined.py
@@ -234,7 +234,9 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta):
lazy_init_param["keys"], lazy_init_param["mode"] = key, mode
resampler = RandAffined(**lazy_init_param).set_random_state(123)
expected_output = resampler(**call_param)
- test_resampler_lazy(resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key)
+ test_resampler_lazy(
+ resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key, rtol=_rtol
+ )
resampler.lazy = False
if input_param.get("cache_grid", False):
diff --git a/tests/test_spatial_resampled.py b/tests/test_spatial_resampled.py
index 541015cc34..d5c86258d7 100644
--- a/tests/test_spatial_resampled.py
+++ b/tests/test_spatial_resampled.py
@@ -11,6 +11,7 @@
from __future__ import annotations
+import platform
import unittest
import numpy as np
@@ -23,6 +24,12 @@
from tests.lazy_transforms_utils import test_resampler_lazy
from tests.utils import TEST_DEVICES, assert_allclose
+ON_AARCH64 = platform.machine() == "aarch64"
+if ON_AARCH64:
+ rtol, atol = 1e-1, 1e-2
+else:
+ rtol, atol = 1e-3, 1e-4
+
TESTS = []
destinations_3d = [
@@ -104,7 +111,7 @@ def test_flips_inverse(self, img, device, dst_affine, kwargs, expected_output):
# check lazy
lazy_xform = SpatialResampled(**init_param)
- test_resampler_lazy(lazy_xform, output_data, init_param, call_param, output_key="img")
+ test_resampler_lazy(lazy_xform, output_data, init_param, call_param, output_key="img", rtol=rtol, atol=atol)
# check inverse
inverted = xform.inverse(output_data)["img"]
From 6ad169a830c828a9ea4c13ab561bd5a29d957a22 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Sun, 10 Mar 2024 23:21:28 +0800
Subject: [PATCH 77/88] Fix error in "test_bundle_trt_export" (#7524)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7523
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/networks/utils.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/monai/networks/utils.py b/monai/networks/utils.py
index 42e537648a..4e6699f16b 100644
--- a/monai/networks/utils.py
+++ b/monai/networks/utils.py
@@ -850,7 +850,10 @@ def _onnx_trt_compile(
# wrap the serialized TensorRT engine back to a TorchScript module.
trt_model = torch_tensorrt.ts.embed_engine_in_new_module(
- f.getvalue(), torch.device(f"cuda:{device}"), input_names, output_names
+ f.getvalue(),
+ device=torch.device(f"cuda:{device}"),
+ input_binding_names=input_names,
+ output_binding_names=output_names,
)
return trt_model
From 9f57cb2c92217f08ffa858da2bf192d4e299df5d Mon Sep 17 00:00:00 2001
From: Fabian Klopfer
Date: Fri, 15 Mar 2024 03:19:57 +0100
Subject: [PATCH 78/88] Fix typo in the PerceptualNetworkType Enum (#7548)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7547
### Description
Previously it was 'medical_resnet50_23datasets' for both identifier and
string, which doesn't correspond to the name in the hubconf.py of
Warvito's repo. Now it is the correct version (according to Warvitos
repo) 'medicalnet_resnet50_23datasets'.
The docs state it correctly already.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ x] New tests added to cover the changes.
- [ x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: Fabian Klopfer
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/losses/perceptual.py | 2 +-
tests/test_perceptual_loss.py | 5 +++++
2 files changed, 6 insertions(+), 1 deletion(-)
diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py
index 2207de5e64..fd61603b03 100644
--- a/monai/losses/perceptual.py
+++ b/monai/losses/perceptual.py
@@ -29,7 +29,7 @@ class PercetualNetworkType(StrEnum):
squeeze = "squeeze"
radimagenet_resnet50 = "radimagenet_resnet50"
medicalnet_resnet10_23datasets = "medicalnet_resnet10_23datasets"
- medical_resnet50_23datasets = "medical_resnet50_23datasets"
+ medicalnet_resnet50_23datasets = "medicalnet_resnet50_23datasets"
resnet50 = "resnet50"
diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py
index ba204af697..02232e6f8d 100644
--- a/tests/test_perceptual_loss.py
+++ b/tests/test_perceptual_loss.py
@@ -40,6 +40,11 @@
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
+ [
+ {"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False},
+ (2, 1, 64, 64, 64),
+ (2, 1, 64, 64, 64),
+ ],
[
{"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2},
(2, 1, 64, 64, 64),
From 5465ae3e038e7fc3c8e39bd68d37731a1e764b56 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Mon, 18 Mar 2024 11:56:45 +0800
Subject: [PATCH 79/88] Update to use `log_sigmoid` in `FocalLoss` (#7534)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7533
### Description
A few sentences describing the changes proposed in this pull request.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/losses/focal_loss.py | 5 ++---
tests/test_focal_loss.py | 2 +-
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 98c1a071b6..28d1c0cdc9 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -234,9 +234,8 @@ def sigmoid_focal_loss(
"""
# computing binary cross entropy with logits
# equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
- # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
- max_val = (-input).clamp(min=0)
- loss: torch.Tensor = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
+ # see also https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L363
+ loss: torch.Tensor = input - input * target - F.logsigmoid(input)
# sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
# 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py
index de8d625058..0bb8a078ae 100644
--- a/tests/test_focal_loss.py
+++ b/tests/test_focal_loss.py
@@ -132,7 +132,7 @@ def test_consistency_with_cross_entropy_2d_no_reduction(self):
error = np.abs(a - b)
max_error = np.maximum(error, max_error)
- assert np.allclose(max_error, 0)
+ assert np.allclose(max_error, 0, atol=1e-6)
def test_consistency_with_cross_entropy_2d_onehot_label(self):
"""For gamma=0 the focal loss reduces to the cross entropy loss"""
From e4a83467630f2f74b2e1760bf3a8fd71b6cd5009 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Fri, 22 Mar 2024 16:03:00 +0800
Subject: [PATCH 80/88] Update integration_segmentation_3d result for
PyTorch2403 (#7551)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fixes #7550
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
runtests.sh | 2 +
tests/testing_data/integration_answers.py | 56 +++++++++++++++++++++++
2 files changed, 58 insertions(+)
diff --git a/runtests.sh b/runtests.sh
index 0c60bc0f58..0b3e20ce49 100755
--- a/runtests.sh
+++ b/runtests.sh
@@ -738,12 +738,14 @@ fi
# network training/inference/eval integration tests
if [ $doNetTests = true ]
then
+ set +e # disable exit on failure so that diagnostics can be given on failure
echo "${separator}${blue}integration${noColor}"
for i in tests/*integration_*.py
do
echo "$i"
${cmdPrefix}${cmd} "$i"
done
+ set -e # enable exit on failure
fi
# run model zoo tests
diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py
index c0dd973418..e02b9ae995 100644
--- a/tests/testing_data/integration_answers.py
+++ b/tests/testing_data/integration_answers.py
@@ -600,6 +600,62 @@
],
}
},
+ { # test answers for 24.03
+ "integration_segmentation_3d": {
+ "losses": [
+ 0.5442982316017151,
+ 0.4741817444562912,
+ 0.4535954713821411,
+ 0.44163046181201937,
+ 0.4307525992393494,
+ 0.428487154841423,
+ ],
+ "best_metric": 0.9314384460449219,
+ "infer_metric": 0.9315622448921204,
+ "output_sums": [
+ 0.14268704426414708,
+ 0.1528672845845743,
+ 0.1521782248125706,
+ 0.14028769128068194,
+ 0.1889830671664784,
+ 0.16999075690664475,
+ 0.14736282992708227,
+ 0.16877952654821815,
+ 0.15779597155181269,
+ 0.17987829927082263,
+ 0.16320253928314676,
+ 0.16854299322173155,
+ 0.14497470986956967,
+ 0.11437140546369519,
+ 0.1624117412960871,
+ 0.20156009294443875,
+ 0.1764654154256958,
+ 0.0982348259217418,
+ 0.1942436068604293,
+ 0.20359421536407518,
+ 0.19661953116976483,
+ 0.2088326101468625,
+ 0.16273043545239807,
+ 0.1326107887439663,
+ 0.1489245275752285,
+ 0.143107476635514,
+ 0.23189027677929547,
+ 0.1613818424566088,
+ 0.14889532196775188,
+ 0.10332622984492143,
+ 0.11940054688302351,
+ 0.13040496302762658,
+ 0.11472123087193181,
+ 0.15307044007394474,
+ 0.16371989575844717,
+ 0.1942898223272055,
+ 0.2230120930471398,
+ 0.1814679187634795,
+ 0.19069496508164732,
+ 0.07537197031940022,
+ ],
+ }
+ },
]
From a85d6a996032d381ce90867d93819d173e8ae610 Mon Sep 17 00:00:00 2001
From: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com>
Date: Fri, 22 Mar 2024 11:58:39 +0100
Subject: [PATCH 81/88] Add Barlow Twins loss for representation learning
(#7530)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
### Description
Addition of the BarlowTwinsLoss class. This cost function is introduced
in the http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf paper
with the aim of disentangling the representations learned on two views
of the same sample, making it a powerful tool for multimodal and
unsupervised learning.
This cost function is similar to the InfoNCE Loss function already
implemented in MONAI
(https://docs.monai.io/en/latest/_modules/monai/losses/contrastive.html#ContrastiveLoss).
However, it differs in several respects: there is no l2-normalisation,
but rather a z-normalisation. In addition, rather than working between
pairs of embeddings, Barlow Twins seeks to decorrelate the components of
the representations.
```math
\mathcal{L}_{BT} := \sum_i (1 - \mathcal{C}_{ii})^2 + \lambda \sum_i \sum_{i\neq j} \mathcal{C}_{ij}^2
```
with $\lambda$ a positive hyperparameters and $\mathcal{C}$ the
cross-correlation matrix
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: Lucas Robinet
Signed-off-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com>
Co-authored-by: Lucas Robinet
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
docs/source/losses.rst | 5 ++
monai/losses/__init__.py | 1 +
monai/losses/barlow_twins.py | 84 ++++++++++++++++++++++++
tests/test_barlow_twins_loss.py | 109 ++++++++++++++++++++++++++++++++
4 files changed, 199 insertions(+)
create mode 100644 monai/losses/barlow_twins.py
create mode 100644 tests/test_barlow_twins_loss.py
diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index e929e9d605..61dd959807 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -73,6 +73,11 @@ Segmentation Losses
.. autoclass:: ContrastiveLoss
:members:
+`BarlowTwinsLoss`
+~~~~~~~~~~~~~~~~~
+.. autoclass:: BarlowTwinsLoss
+ :members:
+
`HausdorffDTLoss`
~~~~~~~~~~~~~~~~~
.. autoclass:: HausdorffDTLoss
diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py
index 92898c81ca..4ebedb2084 100644
--- a/monai/losses/__init__.py
+++ b/monai/losses/__init__.py
@@ -12,6 +12,7 @@
from __future__ import annotations
from .adversarial_loss import PatchAdversarialLoss
+from .barlow_twins import BarlowTwinsLoss
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
from .contrastive import ContrastiveLoss
from .deform import BendingEnergyLoss, DiffusionLoss
diff --git a/monai/losses/barlow_twins.py b/monai/losses/barlow_twins.py
new file mode 100644
index 0000000000..a61acca66e
--- /dev/null
+++ b/monai/losses/barlow_twins.py
@@ -0,0 +1,84 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+from torch.nn.modules.loss import _Loss
+
+
+class BarlowTwinsLoss(_Loss):
+ """
+ The Barlow Twins cost function takes the representations extracted by a neural network from two
+ distorted views and seeks to make the cross-correlation matrix of the two representations tend
+ towards identity. This encourages the neural network to learn similar representations with the least
+ amount of redundancy. This cost function can be used in particular in multimodal learning to work on
+ representations from two modalities. The most common use case is for unsupervised learning, where data
+ augmentations are used to generate 2 distorted views of the same sample to force the encoder to
+ extract useful features for downstream tasks.
+
+ Zbontar, Jure, et al. "Barlow Twins: Self-Supervised Learning via Redundancy Reduction" International
+ conference on machine learning. PMLR, 2020. (http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf)
+
+ Adapted from:
+ https://github.com/facebookresearch/barlowtwins
+
+ """
+
+ def __init__(self, lambd: float = 5e-3) -> None:
+ """
+ Args:
+ lamb: Can be any float to handle the informativeness and invariance trade-off. Ideally set to 5e-3.
+
+ Raises:
+ ValueError: When an input of dimension length > 2 is passed
+ ValueError: When input and target are of different shapes
+ ValueError: When batch size is less than or equal to 1
+
+ """
+ super().__init__()
+ self.lambd = lambd
+
+ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ input: the shape should be B[F].
+ target: the shape should be B[F].
+ """
+ if len(target.shape) > 2 or len(input.shape) > 2:
+ raise ValueError(
+ f"Either target or input has dimensions greater than 2 where target "
+ f"shape is ({target.shape}) and input shape is ({input.shape})"
+ )
+
+ if target.shape != input.shape:
+ raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
+
+ if target.size(0) <= 1:
+ raise ValueError(
+ f"Batch size must be greater than 1 to compute Barlow Twins Loss, but got {target.size(0)}"
+ )
+
+ lambd_tensor = torch.as_tensor(self.lambd).to(input.device)
+ batch_size = input.shape[0]
+
+ # normalize input and target
+ input_norm = (input - input.mean(0)) / input.std(0).add(1e-6)
+ target_norm = (target - target.mean(0)) / target.std(0).add(1e-6)
+
+ # cross-correlation matrix
+ c = torch.mm(input_norm.t(), target_norm) / batch_size # input_norm.t() is FxB, target_norm is BxF so c is FxF
+
+ # loss
+ c_diff = (c - torch.eye(c.size(0), device=c.device)).pow_(2) # FxF
+ c_diff[~torch.eye(c.size(0), device=c.device).bool()] *= lambd_tensor
+
+ return c_diff.sum()
diff --git a/tests/test_barlow_twins_loss.py b/tests/test_barlow_twins_loss.py
new file mode 100644
index 0000000000..81f4032e0c
--- /dev/null
+++ b/tests/test_barlow_twins_loss.py
@@ -0,0 +1,109 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.losses import BarlowTwinsLoss
+
+TEST_CASES = [
+ [ # shape: (2, 4), (2, 4)
+ {"lambd": 5e-3},
+ {
+ "input": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
+ "target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
+ },
+ 4.0,
+ ],
+ [ # shape: (2, 4), (2, 4)
+ {"lambd": 5e-3},
+ {
+ "input": torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]),
+ "target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
+ },
+ 4.0,
+ ],
+ [ # shape: (2, 4), (2, 4)
+ {"lambd": 5e-3},
+ {
+ "input": torch.tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]]),
+ "target": torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 1.0]]),
+ },
+ 5.2562,
+ ],
+ [ # shape: (2, 4), (2, 4)
+ {"lambd": 5e-4},
+ {
+ "input": torch.tensor([[2.0, 3.0, 1.0, 2.0], [0.0, 1.0, 2.0, 5.0]]),
+ "target": torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]),
+ },
+ 5.0015,
+ ],
+ [ # shape: (4, 4), (4, 4)
+ {"lambd": 5e-3},
+ {
+ "input": torch.tensor(
+ [[1.0, 2.0, 1.0, 1.0], [3.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 1.0], [2.0, 1.0, 1.0, 0.0]]
+ ),
+ "target": torch.tensor(
+ [
+ [0.0, 1.0, -1.0, 0.0],
+ [1 / 3, 0.0, -2 / 3, 1 / 3],
+ [-2 / 3, -1.0, 7 / 3, 1 / 3],
+ [1 / 3, 0.0, 1 / 3, -2 / 3],
+ ]
+ ),
+ },
+ 1.4736,
+ ],
+]
+
+
+class TestBarlowTwinsLoss(unittest.TestCase):
+
+ @parameterized.expand(TEST_CASES)
+ def test_result(self, input_param, input_data, expected_val):
+ barlowtwinsloss = BarlowTwinsLoss(**input_param)
+ result = barlowtwinsloss(**input_data)
+ np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
+
+ def test_ill_shape(self):
+ loss = BarlowTwinsLoss(lambd=5e-3)
+ with self.assertRaises(ValueError):
+ loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+ def test_ill_batch_size(self):
+ loss = BarlowTwinsLoss(lambd=5e-3)
+ with self.assertRaises(ValueError):
+ loss(torch.ones((1, 2)), torch.ones((1, 2)))
+
+ def test_with_cuda(self):
+ loss = BarlowTwinsLoss(lambd=5e-3)
+ i = torch.ones((2, 10))
+ j = torch.ones((2, 10))
+ if torch.cuda.is_available():
+ i = i.cuda()
+ j = j.cuda()
+ output = loss(i, j)
+ np.testing.assert_allclose(output.detach().cpu().numpy(), 10.0, atol=1e-4, rtol=1e-4)
+
+ def check_warning_raised(self):
+ with self.assertWarns(Warning):
+ BarlowTwinsLoss(lambd=5e-3, batch_size=1)
+
+
+if __name__ == "__main__":
+ unittest.main()
From 1916a4164d158a64f2d1e2a75791461fa4181097 Mon Sep 17 00:00:00 2001
From: cxlcl
Date: Fri, 22 Mar 2024 09:54:40 -0700
Subject: [PATCH 82/88] Stein's Unbiased Risk Estimator (SURE) loss and
Conjugate Gradient (#7308)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
### Description
Based on the discussion topic
[here](https://github.com/Project-MONAI/MONAI/discussions/7161#discussion-5773293),
we implemented the Conjugate-Gradient algorithm for linear operator
inversion, and Stein's Unbiased Risk Estimator (SURE) [1] loss for
ground-truth-date free diffusion process guidance that is proposed in
[2] and illustrated in the algorithm below:
The Conjugate-Gradient (CG) algorithm is used to solve for the inversion
of the linear operator in Line-4 in the algorithm above, where the
linear operator is too large to store explicitly as a matrix (such as
FFT/IFFT of an image) and invert directly. Instead, we can solve for the
linear inversion iteratively as in CG.
The SURE loss is applied for Line-6 above. This is a differentiable loss
function that can be used to train/giude an operator (e.g. neural
network), where the pseudo ground truth is available but the reference
ground truth is not. For example, in the MRI reconstruction, the pseudo
ground truth is the zero-filled reconstruction and the reference ground
truth is the fully sampled reconstruction. The reference ground truth is
not available due to the lack of fully sampled.
**Reference**
[1] Stein, C.M.: Estimation of the mean of a multivariate normal
distribution. Annals of Statistics 1981 [[paper
link](https://projecteuclid.org/journals/annals-of-statistics/volume-9/issue-6/Estimation-of-the-Mean-of-a-Multivariate-Normal-Distribution/10.1214/aos/1176345632.full)]
[2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with
Diffusion Models. MICCAI 2023
[[paper link](https://arxiv.org/pdf/2310.01799.pdf)]
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: chaoliu
Signed-off-by: cxlcl
Signed-off-by: chaoliu
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
docs/source/losses.rst | 5 +
docs/source/networks.rst | 5 +
monai/losses/__init__.py | 1 +
monai/losses/sure_loss.py | 200 ++++++++++++++++++++
monai/networks/layers/__init__.py | 1 +
monai/networks/layers/conjugate_gradient.py | 112 +++++++++++
tests/test_conjugate_gradient.py | 55 ++++++
tests/test_sure_loss.py | 71 +++++++
8 files changed, 450 insertions(+)
create mode 100644 monai/losses/sure_loss.py
create mode 100644 monai/networks/layers/conjugate_gradient.py
create mode 100644 tests/test_conjugate_gradient.py
create mode 100644 tests/test_sure_loss.py
diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index 61dd959807..ba794af3eb 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -139,6 +139,11 @@ Reconstruction Losses
.. autoclass:: JukeboxLoss
:members:
+`SURELoss`
+~~~~~~~~~~
+.. autoclass:: SURELoss
+ :members:
+
Loss Wrappers
-------------
diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index 8eada7933f..b59c8af5fc 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -408,6 +408,11 @@ Layers
.. autoclass:: LLTM
:members:
+`ConjugateGradient`
+~~~~~~~~~~~~~~~~~~~
+.. autoclass:: ConjugateGradient
+ :members:
+
`Utilities`
~~~~~~~~~~~
.. automodule:: monai.networks.layers.convutils
diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py
index 4ebedb2084..e937b53fa4 100644
--- a/monai/losses/__init__.py
+++ b/monai/losses/__init__.py
@@ -41,5 +41,6 @@
from .spatial_mask import MaskedLoss
from .spectral_loss import JukeboxLoss
from .ssim_loss import SSIMLoss
+from .sure_loss import SURELoss
from .tversky import TverskyLoss
from .unified_focal_loss import AsymmetricUnifiedFocalLoss
diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py
new file mode 100644
index 0000000000..ebf25613a6
--- /dev/null
+++ b/monai/losses/sure_loss.py
@@ -0,0 +1,200 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.loss import _Loss
+
+
+def complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ """
+ First compute the difference in the complex domain,
+ then get the absolute value and take the mse
+
+ Args:
+ x, y - B, 2, H, W real valued tensors representing complex numbers
+ or B,1,H,W complex valued tensors
+ Returns:
+ l2_loss - scalar
+ """
+ if not x.is_complex():
+ x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous())
+ if not y.is_complex():
+ y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous())
+
+ diff = torch.abs(x - y)
+ return nn.functional.mse_loss(diff, torch.zeros_like(diff), reduction="mean")
+
+
+def sure_loss_function(
+ operator: Callable,
+ x: torch.Tensor,
+ y_pseudo_gt: torch.Tensor,
+ y_ref: Optional[torch.Tensor] = None,
+ eps: Optional[float] = -1.0,
+ perturb_noise: Optional[torch.Tensor] = None,
+ complex_input: Optional[bool] = False,
+) -> torch.Tensor:
+ """
+ Args:
+ operator (function): The operator function that takes in an input
+ tensor x and returns an output tensor y. We will use this to compute
+ the divergence. More specifically, we will perturb the input x by a
+ small amount and compute the divergence between the perturbed output
+ and the reference output
+
+ x (torch.Tensor): The input tensor of shape (B, C, H, W) to the
+ operator. For complex input, the shape is (B, 2, H, W) aka C=2 real.
+ For real input, the shape is (B, 1, H, W) real.
+
+ y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape
+ (B, C, H, W) used to compute the L2 loss. For complex input, the shape is
+ (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W)
+ real.
+
+ y_ref (torch.Tensor, optional): The reference output tensor of shape
+ (B, C, H, W) used to compute the divergence. Defaults to None. For
+ complex input, the shape is (B, 2, H, W) aka C=2 real. For real input,
+ the shape is (B, 1, H, W) real.
+
+ eps (float, optional): The perturbation scalar. Set to -1 to set it
+ automatically estimated based on y_pseudo_gtk
+
+ perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W).
+ Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real.
+ For real input, the shape is (B, 1, H, W) real.
+
+ complex_input(bool, optional): Whether the input is complex or not.
+ Defaults to False.
+
+ Returns:
+ sure_loss (torch.Tensor): The SURE loss scalar.
+ """
+ # perturb input
+ if perturb_noise is None:
+ perturb_noise = torch.randn_like(x)
+ if eps == -1.0:
+ eps = float(torch.abs(y_pseudo_gt.max())) / 1000
+ # get y_ref if not provided
+ if y_ref is None:
+ y_ref = operator(x)
+
+ # get perturbed output
+ x_perturbed = x + eps * perturb_noise
+ y_perturbed = operator(x_perturbed)
+ # divergence
+ divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # type: ignore
+ # l2 loss between y_ref, y_pseudo_gt
+ if complex_input:
+ l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt)
+ else:
+ # real input
+ l2_loss = nn.functional.mse_loss(y_ref, y_pseudo_gt, reduction="mean")
+
+ # sure loss
+ sure_loss = l2_loss * divergence / (x.shape[0] * x.shape[2] * x.shape[3])
+ return sure_loss
+
+
+class SURELoss(_Loss):
+ """
+ Calculate the Stein's Unbiased Risk Estimator (SURE) loss for a given operator.
+
+ This is a differentiable loss function that can be used to train/guide an
+ operator (e.g. neural network), where the pseudo ground truth is available
+ but the reference ground truth is not. For example, in the MRI
+ reconstruction, the pseudo ground truth is the zero-filled reconstruction
+ and the reference ground truth is the fully sampled reconstruction. Often,
+ the reference ground truth is not available due to the lack of fully sampled
+ data.
+
+ The original SURE loss is proposed in [1]. The SURE loss used for guiding
+ the diffusion model based MRI reconstruction is proposed in [2].
+
+ Reference
+
+ [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics
+
+ [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models.
+ (https://arxiv.org/pdf/2310.01799.pdf)
+ """
+
+ def __init__(self, perturb_noise: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> None:
+ """
+ Args:
+ perturb_noise (torch.Tensor, optional): The noise vector of shape
+ (B, C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real.
+ For real input, the shape is (B, 1, H, W) real.
+
+ eps (float, optional): The perturbation scalar. Defaults to None.
+ """
+ super().__init__()
+ self.perturb_noise = perturb_noise
+ self.eps = eps
+
+ def forward(
+ self,
+ operator: Callable,
+ x: torch.Tensor,
+ y_pseudo_gt: torch.Tensor,
+ y_ref: Optional[torch.Tensor] = None,
+ complex_input: Optional[bool] = False,
+ ) -> torch.Tensor:
+ """
+ Args:
+ operator (function): The operator function that takes in an input
+ tensor x and returns an output tensor y. We will use this to compute
+ the divergence. More specifically, we will perturb the input x by a
+ small amount and compute the divergence between the perturbed output
+ and the reference output
+
+ x (torch.Tensor): The input tensor of shape (B, C, H, W) to the
+ operator. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka
+ C=2 real. For real input, the shape is (B, 1, H, W) real.
+
+ y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape
+ (B, C, H, W) used to compute the L2 loss. C=1 or 2: For complex
+ input, the shape is (B, 2, H, W) aka C=2 real. For real input, the
+ shape is (B, 1, H, W) real.
+
+ y_ref (torch.Tensor, optional): The reference output tensor of the
+ same shape as y_pseudo_gt
+
+ Returns:
+ sure_loss (torch.Tensor): The SURE loss scalar.
+ """
+
+ # check inputs shapes
+ if x.dim() != 4:
+ raise ValueError(f"Input tensor x should be 4D, got {x.dim()}.")
+ if y_pseudo_gt.dim() != 4:
+ raise ValueError(f"Input tensor y_pseudo_gt should be 4D, but got {y_pseudo_gt.dim()}.")
+ if y_ref is not None and y_ref.dim() != 4:
+ raise ValueError(f"Input tensor y_ref should be 4D, but got {y_ref.dim()}.")
+ if x.shape != y_pseudo_gt.shape:
+ raise ValueError(
+ f"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, "
+ f"y_pseudo_gt shape {y_pseudo_gt.shape}."
+ )
+ if y_ref is not None and y_pseudo_gt.shape != y_ref.shape:
+ raise ValueError(
+ f"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape {y_pseudo_gt.shape}, "
+ f"y_ref shape {y_ref.shape}."
+ )
+
+ # compute loss
+ loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input)
+
+ return loss
diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py
index d61ed57f7f..3a6e4aa554 100644
--- a/monai/networks/layers/__init__.py
+++ b/monai/networks/layers/__init__.py
@@ -11,6 +11,7 @@
from __future__ import annotations
+from .conjugate_gradient import ConjugateGradient
from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding
from .drop_path import DropPath
from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args
diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py
new file mode 100644
index 0000000000..93a45930d7
--- /dev/null
+++ b/monai/networks/layers/conjugate_gradient.py
@@ -0,0 +1,112 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Callable
+
+import torch
+from torch import nn
+
+
+def _zdot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+ """
+ Complex dot product between tensors x1 and x2: sum(x1.*x2)
+ """
+ if torch.is_complex(x1):
+ assert torch.is_complex(x2), "x1 and x2 must both be complex"
+ return torch.sum(x1.conj() * x2)
+ else:
+ return torch.sum(x1 * x2)
+
+
+def _zdot_single(x: torch.Tensor) -> torch.Tensor:
+ """
+ Complex dot product between tensor x and itself
+ """
+ res = _zdot(x, x)
+ if torch.is_complex(res):
+ return res.real
+ else:
+ return res
+
+
+class ConjugateGradient(nn.Module):
+ """
+ Congugate Gradient (CG) solver for linear systems Ax = y.
+
+ For linear_op that is positive definite and self-adjoint, CG is
+ guaranteed to converge CG is often used to solve linear systems of the form
+ Ax = y, where A is too large to store explicitly, but can be computed via a
+ linear operator.
+
+ As a result, here we won't set A explicitly as a matrix, but rather as a
+ linear operator. For example, A could be a FFT/IFFT operation
+ """
+
+ def __init__(self, linear_op: Callable, num_iter: int):
+ """
+ Args:
+ linear_op: Linear operator
+ num_iter: Number of iterations to run CG
+ """
+ super().__init__()
+
+ self.linear_op = linear_op
+ self.num_iter = num_iter
+
+ def update(
+ self, x: torch.Tensor, p: torch.Tensor, r: torch.Tensor, rsold: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ perform one iteration of the CG method. It takes the current solution x,
+ the current search direction p, the current residual r, and the old
+ residual norm rsold as inputs. Then it computes the new solution, search
+ direction, residual, and residual norm, and returns them.
+ """
+
+ dy = self.linear_op(p)
+ p_dot_dy = _zdot(p, dy)
+ alpha = rsold / p_dot_dy
+ x = x + alpha * p
+ r = r - alpha * dy
+ rsnew = _zdot_single(r)
+ beta = rsnew / rsold
+ rsold = rsnew
+ p = beta * p + r
+ return x, p, r, rsold
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ """
+ run conjugate gradient for num_iter iterations to solve Ax = y
+
+ Args:
+ x: tensor (real or complex); Initial guess for linear system Ax = y.
+ The size of x should be applicable to the linear operator. For
+ example, if the linear operator is FFT, then x is HCHW; if the
+ linear operator is a matrix multiplication, then x is a vector
+
+ y: tensor (real or complex); Measurement. Same size as x
+
+ Returns:
+ x: Solution to Ax = y
+ """
+ # Compute residual
+ r = y - self.linear_op(x)
+ rsold = _zdot_single(r)
+ p = r
+
+ # Update
+ for _i in range(self.num_iter):
+ x, p, r, rsold = self.update(x, p, r, rsold)
+ if rsold < 1e-10:
+ break
+ return x
diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py
new file mode 100644
index 0000000000..239dbe3ecd
--- /dev/null
+++ b/tests/test_conjugate_gradient.py
@@ -0,0 +1,55 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+
+from monai.networks.layers import ConjugateGradient
+
+
+class TestConjugateGradient(unittest.TestCase):
+ def test_real_valued_inverse(self):
+ """Test ConjugateGradient with real-valued input: when the input is real
+ value, the output should be the inverse of the matrix."""
+ a_dim = 3
+ a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.float)
+
+ def a_op(x):
+ return a_mat @ x
+
+ cg_solver = ConjugateGradient(a_op, num_iter=100)
+ # define the measurement
+ y = torch.tensor([1, 2, 3], dtype=torch.float)
+ # solve for x
+ x = cg_solver(torch.zeros(a_dim), y)
+ x_ref = torch.linalg.solve(a_mat, y)
+ # assert torch.allclose(x, x_ref, atol=1e-6), 'CG solver failed to converge to reference solution'
+ self.assertTrue(torch.allclose(x, x_ref, atol=1e-6))
+
+ def test_complex_valued_inverse(self):
+ a_dim = 3
+ a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.complex64)
+
+ def a_op(x):
+ return a_mat @ x
+
+ cg_solver = ConjugateGradient(a_op, num_iter=100)
+ y = torch.tensor([1, 2, 3], dtype=torch.complex64)
+ x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y)
+ x_ref = torch.linalg.solve(a_mat, y)
+ self.assertTrue(torch.allclose(x, x_ref, atol=1e-6))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py
new file mode 100644
index 0000000000..945da657bf
--- /dev/null
+++ b/tests/test_sure_loss.py
@@ -0,0 +1,71 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+
+from monai.losses import SURELoss
+
+
+class TestSURELoss(unittest.TestCase):
+ def test_real_value(self):
+ """Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0."""
+ sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1)
+
+ def operator(x):
+ return x
+
+ y_pseudo_gt = torch.randn(2, 1, 128, 128)
+ x = torch.randn(2, 1, 128, 128)
+ loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False)
+ self.assertAlmostEqual(loss.item(), 0.0)
+
+ def test_complex_value(self):
+ """Test SURELoss with complex-valued input: when the input is complex value, the loss should be 0.0."""
+
+ def operator(x):
+ return x
+
+ sure_loss_complex = SURELoss(perturb_noise=torch.zeros(2, 2, 128, 128), eps=0.1)
+ y_pseudo_gt = torch.randn(2, 2, 128, 128)
+ x = torch.randn(2, 2, 128, 128)
+ loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True)
+ self.assertAlmostEqual(loss.item(), 0.0)
+
+ def test_complex_general_input(self):
+ """Test SURELoss with complex-valued input: when the input is general complex value, the loss should be 0.0."""
+
+ def operator(x):
+ return x
+
+ perturb_noise_real = torch.randn(2, 1, 128, 128)
+ perturb_noise_complex = torch.zeros(2, 2, 128, 128)
+ perturb_noise_complex[:, 0, :, :] = perturb_noise_real.squeeze()
+ y_pseudo_gt_real = torch.randn(2, 1, 128, 128)
+ y_pseudo_gt_complex = torch.zeros(2, 2, 128, 128)
+ y_pseudo_gt_complex[:, 0, :, :] = y_pseudo_gt_real.squeeze()
+ x_real = torch.randn(2, 1, 128, 128)
+ x_complex = torch.zeros(2, 2, 128, 128)
+ x_complex[:, 0, :, :] = x_real.squeeze()
+
+ sure_loss_real = SURELoss(perturb_noise=perturb_noise_real, eps=0.1)
+ sure_loss_complex = SURELoss(perturb_noise=perturb_noise_complex, eps=0.1)
+
+ loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False)
+ loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True)
+ self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6)
+
+
+if __name__ == "__main__":
+ unittest.main()
From 7d48f9ed25c6ba56d5e570bae51d92a3c06be1ce Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Sun, 24 Mar 2024 20:29:37 +0100
Subject: [PATCH 83/88] fixed code format checks
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/regularization/array.py | 4 ++--
monai/transforms/regularization/dictionary.py | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py
index 1d9b5538fb..cd2917bfe9 100644
--- a/monai/transforms/regularization/array.py
+++ b/monai/transforms/regularization/array.py
@@ -16,12 +16,12 @@
import torch
-from monai.transforms import Randomizable, Transform
+from ..transform import RandomizableTransform
__all__ = ["MixUp", "CutMix", "CutOut"]
-class Mixer(Transform, Randomizable):
+class Mixer(RandomizableTransform):
def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
super().__init__()
if alpha <= 0:
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
index 86a361b8a4..3f7ed90819 100644
--- a/monai/transforms/regularization/dictionary.py
+++ b/monai/transforms/regularization/dictionary.py
@@ -12,9 +12,9 @@
from __future__ import annotations
from monai.config import KeysCollection
-from monai.transforms import MapTransform
from monai.utils.misc import ensure_tuple
+from ..transform import MapTransform
from .array import CutMix, CutOut, MixUp
__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"]
From 292e84dbc496c5f0342d461d577fa7f07e034fcb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Wed, 8 Nov 2023 22:11:53 +0100
Subject: [PATCH 84/88] added feedback suggestions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/regularization/array.py | 43 ++++++++++++++++---
monai/transforms/regularization/dictionary.py | 25 ++++-------
2 files changed, 44 insertions(+), 24 deletions(-)
diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py
index cd2917bfe9..f0c91051c0 100644
--- a/monai/transforms/regularization/array.py
+++ b/monai/transforms/regularization/array.py
@@ -18,11 +18,23 @@
from ..transform import RandomizableTransform
-__all__ = ["MixUp", "CutMix", "CutOut"]
+__all__ = ["MixUp", "CutMix", "CutOut", "Mixer"]
class Mixer(RandomizableTransform):
def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
+ """
+ Mixer is a base class providing the basic logic for the mixup-class of
+ augmentations. In all cases, we need to sample the mixing weights for each
+ sample (lambda in the notation used in the papers). Also, pairs of samples
+ being mixed are picked by randomly shuffling the batch samples.
+
+ Args:
+ batch_size (int): number of samples per batch. That is, samples are expected tp
+ be of size batchsize x channels [x depth] x height x width.
+ alpha (float, optional): mixing weights are sampled from the Beta(alpha, alpha)
+ distribution. Defaults to 1.0, the uniform distribution.
+ """
super().__init__()
if alpha <= 0:
raise ValueError(f"Expected positive number, but got {alpha = }")
@@ -50,10 +62,10 @@ class MixUp(Mixer):
"""MixUp as described in:
Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.
mixup: Beyond Empirical Risk Minimization, ICLR 2018
- """
- def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
- super().__init__(batch_size, alpha)
+ Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
+ documentation for details on the constructor parameters.
+ """
def apply(self, data: torch.Tensor):
weight, perm = self._params
@@ -79,10 +91,22 @@ class CutMix(Mixer):
Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
ICCV 2019
- """
- def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
- super().__init__(batch_size, alpha)
+ Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
+ documentation for details on the constructor parameters. Here, alpha not only determines
+ the mixing weight but also the size of the random rectangles used during for mixing.
+ Please refer to the paper for details.
+
+ The most common use case is something close to:
+
+ cm = CutMix(batch_size=8, alpha=0.5)
+ for batch in loader:
+ images, labels = batch
+ augimg, auglabels = cm(images, labels)
+ output = model(augimg)
+ loss = loss_function(output, auglabels)
+ ...
+ """
def apply(self, data: torch.Tensor):
weights, perm = self._params
@@ -119,6 +143,11 @@ class CutOut(Mixer):
Terrance DeVries, Graham W. Taylor.
Improved Regularization of Convolutional Neural Networks with Cutout,
arXiv:1708.04552
+
+ Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
+ documentation for details on the constructor parameters. Here, alpha not only determines
+ the mixing weight but also the size of the random rectangles being cut put.
+ Please refer to the paper for details.
"""
def apply(self, data: torch.Tensor):
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
index 3f7ed90819..373913da99 100644
--- a/monai/transforms/regularization/dictionary.py
+++ b/monai/transforms/regularization/dictionary.py
@@ -21,9 +21,8 @@
class MixUpd(MapTransform):
- """MixUp as described in:
- Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.
- mixup: Beyond Empirical Risk Minimization, ICLR 2018
+ """
+ Dictionary-based version :py:class:`monai.transforms.MixUp`.
Notice that the mixup transformation will be the same for all entries
for consistency, i.e. images and labels must be applied the same augmenation.
@@ -43,14 +42,9 @@ def __call__(self, data):
return result
-MixUpD = MixUpDict = MixUpd
-
-
class CutMixd(MapTransform):
- """CutMix augmentation as described in:
- Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo
- CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
- ICCV 2019
+ """
+ Dictionary-based version :py:class:`monai.transforms.CutMix`.
Notice that the mixture weights will be the same for all entries
for consistency, i.e. images and labels must be aggregated with the same weights,
@@ -79,14 +73,9 @@ def __call__(self, data):
return result
-CutMixD = CutMixDict = CutMixd
-
-
class CutOutd(MapTransform):
- """Cutout as described in the paper:
- Terrance DeVries, Graham W. Taylor
- Improved Regularization of Convolutional Neural Networks with Cutout
- arXiv:1708.04552
+ """
+ Dictionary-based version :py:class:`monai.transforms.CutOut`.
Notice that the cutout is different for every entry in the dictionary.
"""
@@ -103,4 +92,6 @@ def __call__(self, data):
return result
+MixUpD = MixUpDict = MixUpd
+CutMixD = CutMixDict = CutMixd
CutOutD = CutOutDict = CutOutd
From 2606758329e7481bfa1e52bd8b955e01c09d7dd8 Mon Sep 17 00:00:00 2001
From: monai-bot <64792179+monai-bot@users.noreply.github.com>
Date: Mon, 25 Mar 2024 07:26:43 +0000
Subject: [PATCH 85/88] auto updates (#7577)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: monai-bot
Signed-off-by: monai-bot
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
tests/test_conjugate_gradient.py | 1 +
tests/test_sure_loss.py | 1 +
2 files changed, 2 insertions(+)
diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py
index 239dbe3ecd..64efe3b168 100644
--- a/tests/test_conjugate_gradient.py
+++ b/tests/test_conjugate_gradient.py
@@ -19,6 +19,7 @@
class TestConjugateGradient(unittest.TestCase):
+
def test_real_valued_inverse(self):
"""Test ConjugateGradient with real-valued input: when the input is real
value, the output should be the inverse of the matrix."""
diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py
index 945da657bf..903f9bd2ca 100644
--- a/tests/test_sure_loss.py
+++ b/tests/test_sure_loss.py
@@ -19,6 +19,7 @@
class TestSURELoss(unittest.TestCase):
+
def test_real_value(self):
"""Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0."""
sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1)
From c9a6521eebb39ba32aeb18f6a36b9634cc0f33ae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Mon, 25 Mar 2024 13:03:02 +0100
Subject: [PATCH 86/88] flake8 warnings
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/regularization/dictionary.py | 4 +---
tests/test_regularization.py | 6 +++---
2 files changed, 4 insertions(+), 6 deletions(-)
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
index 373913da99..5b67b43d93 100644
--- a/monai/transforms/regularization/dictionary.py
+++ b/monai/transforms/regularization/dictionary.py
@@ -28,9 +28,7 @@ class MixUpd(MapTransform):
for consistency, i.e. images and labels must be applied the same augmenation.
"""
- def __init__(
- self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
- ) -> None:
+ def __init__(self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False) -> None:
super().__init__(keys, allow_missing_keys)
self.mixup = MixUp(batch_size, alpha)
diff --git a/tests/test_regularization.py b/tests/test_regularization.py
index 8b974e392d..d381ea72ca 100644
--- a/tests/test_regularization.py
+++ b/tests/test_regularization.py
@@ -26,7 +26,7 @@ def test_mixup(self):
mixup = MixUp(6, 1.0)
output = mixup(sample)
self.assertEqual(output.shape, sample.shape)
- self.assertTrue(any([not torch.allclose(sample, mixup(sample)) for _ in range(10)]))
+ self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10)))
with self.assertRaises(ValueError):
MixUp(6, -0.5)
@@ -59,7 +59,7 @@ def test_cutmix(self):
cutmix = CutMix(6, 1.0)
output = cutmix(sample)
self.assertEqual(output.shape, sample.shape)
- self.assertTrue(any([not torch.allclose(sample, cutmix(sample)) for _ in range(10)]))
+ self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10)))
def test_cutmixd(self):
for dims in [2, 3]:
@@ -83,7 +83,7 @@ def test_cutout(self):
cutout = CutOut(6, 1.0)
output = cutout(sample)
self.assertEqual(output.shape, sample.shape)
- self.assertTrue(any([not torch.allclose(sample, cutout(sample)) for _ in range(10)]))
+ self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10)))
if __name__ == "__main__":
From af42d65f759228d623c036ca03cf0cce8a12391c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Mon, 25 Mar 2024 20:53:33 +0100
Subject: [PATCH 87/88] =?UTF-8?q?DCO=20Remediation=20Commit=20for=20Juan?=
=?UTF-8?q?=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez=20?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
I, Juan Pablo de la Cruz Gutiérrez , hereby add my Signed-off-by to this commit: b899421f4bc2b33b8bf6b76d898eaf95df59e853
I, Juan Pablo de la Cruz Gutiérrez , hereby add my Signed-off-by to this commit: aaa640fdf0c0a1348678b91215930043da3f317e
I, Juan Pablo de la Cruz Gutiérrez , hereby add my Signed-off-by to this commit: c85976f190c0f1ce6ceb50405c41335e0413127f
I, Juan Pablo de la Cruz Gutiérrez , hereby add my Signed-off-by to this commit: f2fe14e8550974a0026857d4bb2c337a653567d6
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/regularization/array.py | 2 +-
monai/transforms/regularization/dictionary.py | 4 +++-
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py
index f0c91051c0..24903ff3ad 100644
--- a/monai/transforms/regularization/array.py
+++ b/monai/transforms/regularization/array.py
@@ -42,7 +42,7 @@ def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
self.batch_size = batch_size
@abstractmethod
- def apply(cls, data: torch.Tensor):
+ def apply(self, data: torch.Tensor):
raise NotImplementedError()
def randomize(self, data=None) -> None:
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
index 5b67b43d93..373913da99 100644
--- a/monai/transforms/regularization/dictionary.py
+++ b/monai/transforms/regularization/dictionary.py
@@ -28,7 +28,9 @@ class MixUpd(MapTransform):
for consistency, i.e. images and labels must be applied the same augmenation.
"""
- def __init__(self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False) -> None:
+ def __init__(
+ self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
+ ) -> None:
super().__init__(keys, allow_missing_keys)
self.mixup = MixUp(batch_size, alpha)
From 859852e214d070bbf8cdbaf58fcbe9bda29d1583 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?=
Date: Mon, 25 Mar 2024 21:09:21 +0100
Subject: [PATCH 88/88] finally got sphinx format right
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: Juan Pablo de la Cruz Gutiérrez
---
monai/transforms/regularization/array.py | 19 +++++++++++--------
1 file changed, 11 insertions(+), 8 deletions(-)
diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py
index 24903ff3ad..6c9022d647 100644
--- a/monai/transforms/regularization/array.py
+++ b/monai/transforms/regularization/array.py
@@ -88,16 +88,18 @@ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
class CutMix(Mixer):
"""CutMix augmentation as described in:
- Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.
- CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
- ICCV 2019
+ Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.
+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
+ ICCV 2019
- Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
- documentation for details on the constructor parameters. Here, alpha not only determines
- the mixing weight but also the size of the random rectangles used during for mixing.
- Please refer to the paper for details.
+ Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
+ documentation for details on the constructor parameters. Here, alpha not only determines
+ the mixing weight but also the size of the random rectangles used during for mixing.
+ Please refer to the paper for details.
- The most common use case is something close to:
+ The most common use case is something close to:
+
+ .. code-block:: python
cm = CutMix(batch_size=8, alpha=0.5)
for batch in loader:
@@ -106,6 +108,7 @@ class CutMix(Mixer):
output = model(augimg)
loss = loss_function(output, auglabels)
...
+
"""
def apply(self, data: torch.Tensor):