From 04899b65a8a4644f0205e9399d29565800bc04d2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 6 Sep 2021 18:00:07 +0800 Subject: [PATCH 1/4] [DLMED] fix type issue Signed-off-by: Nic Ma --- monai/transforms/utility/array.py | 4 +++- monai/utils/type_conversion.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index dd045817fb..9b4f056196 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -323,7 +323,9 @@ def __call__(self, img: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch. """ if not isinstance(img, (torch.Tensor, np.ndarray)): raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.") - img_out, *_ = convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype) + output_type = torch.Tensor if isinstance(img, torch.Tensor) else np.ndarray + img_out, *_ = convert_data_type(img, output_type=output_type, dtype=dtype or self.dtype) + return img_out diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 14300eeca0..e222cb8ede 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -171,7 +171,7 @@ def convert_data_type( Returns: modified data, orig_type, orig_device """ - orig_type = type(data) + orig_type = torch.Tensor if isinstance(data, torch.Tensor) else np.ndarray orig_device = data.device if isinstance(data, torch.Tensor) else None output_type = output_type or orig_type @@ -205,4 +205,5 @@ def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor device = None if isinstance(dst, torch.Tensor): device = dst.device - return convert_data_type(data=src, output_type=type(dst), device=device, dtype=dst.dtype) + output_type = torch.Tensor if isinstance(dst, torch.Tensor) else np.ndarray + return convert_data_type(data=src, output_type=output_type, device=device, dtype=dst.dtype) From a252b0d266a49acc5c78778d6da4db7736c62293 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 6 Sep 2021 18:55:37 +0800 Subject: [PATCH 2/4] [DLMED] fix test Signed-off-by: Nic Ma --- monai/utils/type_conversion.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index e222cb8ede..6256f9c215 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -171,7 +171,13 @@ def convert_data_type( Returns: modified data, orig_type, orig_device """ - orig_type = torch.Tensor if isinstance(data, torch.Tensor) else np.ndarray + if isinstance(data, torch.Tensor): + orig_type = torch.Tensor + elif isinstance(data, np.ndarray): + orig_type = np.ndarray + else: + orig_type = type(data) + orig_device = data.device if isinstance(data, torch.Tensor) else None output_type = output_type or orig_type From beb1c7654ea3d8e1aa64eb9127217ce2e21dff8d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 6 Sep 2021 19:04:15 +0800 Subject: [PATCH 3/4] [DLMED] simplify the change Signed-off-by: Nic Ma --- monai/transforms/utility/array.py | 4 +--- monai/utils/type_conversion.py | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 9b4f056196..dd045817fb 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -323,9 +323,7 @@ def __call__(self, img: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch. """ if not isinstance(img, (torch.Tensor, np.ndarray)): raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.") - output_type = torch.Tensor if isinstance(img, torch.Tensor) else np.ndarray - img_out, *_ = convert_data_type(img, output_type=output_type, dtype=dtype or self.dtype) - + img_out, *_ = convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype) return img_out diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 6256f9c215..776a684ed6 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -211,5 +211,4 @@ def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor device = None if isinstance(dst, torch.Tensor): device = dst.device - output_type = torch.Tensor if isinstance(dst, torch.Tensor) else np.ndarray - return convert_data_type(data=src, output_type=output_type, device=device, dtype=dst.dtype) + return convert_data_type(data=src, output_type=type(dst), device=device, dtype=dst.dtype) From c481122a21cb27b965f386c79f8285bf943376b5 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 6 Sep 2021 22:36:11 +0800 Subject: [PATCH 4/4] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/utils/type_conversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 776a684ed6..b0ce187e38 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -171,6 +171,7 @@ def convert_data_type( Returns: modified data, orig_type, orig_device """ + orig_type: Any if isinstance(data, torch.Tensor): orig_type = torch.Tensor elif isinstance(data, np.ndarray):