From f5a7da8bbe676f2efa67254f13cdcdd76eede26a Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 22 Aug 2024 14:36:18 +0800 Subject: [PATCH 1/4] fix #8040 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/utils_pytorch_numpy_unification.py | 2 ++ tests/test_utils_pytorch_numpy_unification.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 020d99af16..c15bdf387a 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -546,6 +546,8 @@ def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.min(x, int(dim), **kwargs) # type: ignore + if isinstance(ret, tuple): + return ret.values return ret diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index 6e655289e4..691fc7e9fb 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -17,7 +17,7 @@ import torch from parameterized import parameterized -from monai.transforms.utils_pytorch_numpy_unification import mode, percentile +from monai.transforms.utils_pytorch_numpy_unification import mode, percentile, min from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_quick @@ -27,6 +27,11 @@ TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False]) TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True]) +TEST_MIN = [] +for p in TEST_NDARRAYS: + TEST_MIN.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, p(1)]) + TEST_MIN.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, p([3.1, 3. ])]) + class TestPytorchNumpyUnification(unittest.TestCase): @@ -73,6 +78,11 @@ def test_dim(self): def test_mode(self, array, expected, to_long): res = mode(array, to_long=to_long) assert_allclose(res, expected) + + @parameterized.expand(TEST_MIN) + def test_min(self, array, input_params, expected): + res = min(array, **input_params) + assert_allclose(res, expected, type_test=False) if __name__ == "__main__": From 6d57d5dfaa138d0bb2a137f58c143aeee3e3782d Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 22 Aug 2024 14:39:36 +0800 Subject: [PATCH 2/4] add max Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .../utils_pytorch_numpy_unification.py | 2 ++ tests/test_utils_pytorch_numpy_unification.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index c15bdf387a..05e8f86e9b 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -480,6 +480,8 @@ def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.max(x, int(dim), **kwargs) # type: ignore + if isinstance(ret, tuple): + return ret.values return ret diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index 691fc7e9fb..df935b465d 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -17,7 +17,7 @@ import torch from parameterized import parameterized -from monai.transforms.utils_pytorch_numpy_unification import mode, percentile, min +from monai.transforms.utils_pytorch_numpy_unification import min, max, mode, percentile from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_quick @@ -27,10 +27,12 @@ TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False]) TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True]) -TEST_MIN = [] +TEST_MIN_MAX = [] for p in TEST_NDARRAYS: - TEST_MIN.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, p(1)]) - TEST_MIN.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, p([3.1, 3. ])]) + TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, min, p(1)]) + TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, min, p([3.1, 3])]) + TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, max, p(5)]) + TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, max, p([5.1, 5])]) class TestPytorchNumpyUnification(unittest.TestCase): @@ -78,10 +80,10 @@ def test_dim(self): def test_mode(self, array, expected, to_long): res = mode(array, to_long=to_long) assert_allclose(res, expected) - - @parameterized.expand(TEST_MIN) - def test_min(self, array, input_params, expected): - res = min(array, **input_params) + + @parameterized.expand(TEST_MIN_MAX) + def test_min_max(self, array, input_params, func, expected): + res = func(array, **input_params) assert_allclose(res, expected, type_test=False) From 5e27890db2c530f782d68d00fa3d9452360972a6 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 22 Aug 2024 15:02:56 +0800 Subject: [PATCH 3/4] fix ci Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/utils_pytorch_numpy_unification.py | 4 ++-- tests/test_utils_pytorch_numpy_unification.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 05e8f86e9b..09fc6f9e5c 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -480,7 +480,7 @@ def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.max(x, int(dim), **kwargs) # type: ignore - if isinstance(ret, tuple): + if isinstance(ret, (torch.return_types.max, torch.return_types.min)): return ret.values return ret @@ -548,7 +548,7 @@ def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.min(x, int(dim), **kwargs) # type: ignore - if isinstance(ret, tuple): + if isinstance(ret, (torch.return_types.max, torch.return_types.min)): return ret.values return ret diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index df935b465d..90c0401e46 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -17,7 +17,7 @@ import torch from parameterized import parameterized -from monai.transforms.utils_pytorch_numpy_unification import min, max, mode, percentile +from monai.transforms.utils_pytorch_numpy_unification import max, min, mode, percentile from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_quick From 7b444f1d557bb5a164935ac93fb45bdb442ea65a Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 22 Aug 2024 15:22:29 +0800 Subject: [PATCH 4/4] fix ci Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/utils_pytorch_numpy_unification.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 09fc6f9e5c..98b75cff76 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -480,9 +480,7 @@ def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.max(x, int(dim), **kwargs) # type: ignore - if isinstance(ret, (torch.return_types.max, torch.return_types.min)): - return ret.values - return ret + return ret[0] if isinstance(ret, tuple) else ret def mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor: @@ -548,9 +546,7 @@ def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.min(x, int(dim), **kwargs) # type: ignore - if isinstance(ret, (torch.return_types.max, torch.return_types.min)): - return ret.values - return ret + return ret[0] if isinstance(ret, tuple) else ret def std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor: