diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 73fbd6666f..57d43db5d0 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -330,8 +330,6 @@ def __call__(self, img: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch. TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``. """ - if not isinstance(img, (torch.Tensor, np.ndarray)): - raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.") img_out, *_ = convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype) return img_out diff --git a/tests/test_cast_to_type.py b/tests/test_cast_to_type.py index 0ef25cbafa..d06efb17b5 100644 --- a/tests/test_cast_to_type.py +++ b/tests/test_cast_to_type.py @@ -16,14 +16,23 @@ from parameterized import parameterized from monai.transforms import CastToType +from monai.utils import optional_import from monai.utils.type_conversion import get_equivalent_dtype from tests.utils import TEST_NDARRAYS +cp, has_cp = optional_import("cupy") + TESTS = [] for p in TEST_NDARRAYS: for out_dtype in (np.float64, torch.float64): TESTS.append([out_dtype, p(np.array([[0, 1], [1, 2]], dtype=np.float32)), out_dtype]) +TESTS_CUPY = [ + [np.float32, np.array([[0, 1], [1, 2]], dtype=np.float32), np.float32], + [np.float32, np.array([[0, 1], [1, 2]], dtype=np.uint8), np.float32], + [np.uint8, np.array([[0, 1], [1, 2]], dtype=np.float32), np.uint8], +] + class TestCastToType(unittest.TestCase): @parameterized.expand(TESTS) @@ -35,6 +44,19 @@ def test_type(self, out_dtype, input_data, expected_type): result = CastToType()(input_data, out_dtype) self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result))) + @parameterized.expand(TESTS_CUPY) + @unittest.skipUnless(has_cp, "Requires CuPy") + def test_type_cupy(self, out_dtype, input_data, expected_type): + input_data = cp.asarray(input_data) + + result = CastToType(dtype=out_dtype)(input_data) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result))) + + result = CastToType()(input_data, out_dtype) + self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result))) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_cast_to_typed.py b/tests/test_cast_to_typed.py index be495564fb..342a677ce1 100644 --- a/tests/test_cast_to_typed.py +++ b/tests/test_cast_to_typed.py @@ -16,6 +16,9 @@ from parameterized import parameterized from monai.transforms import CastToTyped +from monai.utils import optional_import + +cp, has_cp = optional_import("cupy") TEST_CASE_1 = [ {"keys": ["img"], "dtype": np.float64}, @@ -33,6 +36,26 @@ ] +TESTS_CUPY = [ + [ + {"keys": "image", "dtype": np.uint8}, + { + "image": np.array([[0, 1], [1, 2]], dtype=np.float32), + "label": np.array([[0, 1], [1, 1]], dtype=np.float32), + }, + {"image": np.uint8, "label": np.float32}, + ], + [ + {"keys": ["image", "label"], "dtype": np.float32}, + { + "image": np.array([[0, 1], [1, 2]], dtype=np.uint8), + "label": np.array([[0, 1], [1, 1]], dtype=np.uint8), + }, + {"image": np.float32, "label": np.float32}, + ], +] + + class TestCastToTyped(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_type(self, input_param, input_data, expected_type): @@ -40,6 +63,16 @@ def test_type(self, input_param, input_data, expected_type): for k, v in result.items(): self.assertEqual(v.dtype, expected_type[k]) + @parameterized.expand(TESTS_CUPY) + @unittest.skipUnless(has_cp, "Requires CuPy") + def test_type_cupy(self, input_param, input_data, expected_type): + input_data = {k: cp.asarray(v) for k, v in input_data.items()} + + result = CastToTyped(**input_param)(input_data) + for k, v in result.items(): + self.assertTrue(isinstance(v, cp.ndarray)) + self.assertEqual(v.dtype, expected_type[k]) + if __name__ == "__main__": unittest.main()