diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 3c1fc0abed..29342742cb 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -669,7 +669,7 @@ class CacheDataset(Dataset): def __init__( self, data: Sequence, - transform: Union[Sequence[Callable], Callable], + transform: Optional[Union[Sequence[Callable], Callable]] = None, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: Optional[int] = 1, @@ -856,7 +856,7 @@ class SmartCacheDataset(Randomizable, CacheDataset): Args: data: input data to load and transform to generate dataset for model. transform: transforms to execute operations on input data. - replace_rate: percentage of the cached items to be replaced in every epoch. + replace_rate: percentage of the cached items to be replaced in every epoch (default to 0.1). cache_num: number of items to be cached. Default is `sys.maxsize`. will take the minimum of (cache_num, data_length x cache_rate, data_length). cache_rate: percentage of cached data in total, default is 1.0 (cache all). 
@@ -883,8 +883,8 @@ class SmartCacheDataset(Randomizable, CacheDataset): def __init__( self, data: Sequence, - transform: Union[Sequence[Callable], Callable], - replace_rate: float, + transform: Optional[Union[Sequence[Callable], Callable]] = None, + replace_rate: float = 0.1, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_init_workers: Optional[int] = 1, diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index 4b77d4a55a..4fa1b5ea69 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -55,6 +55,12 @@ def test_shape(self, transform, expected_shape): data4 = dataset[-1] self.assertEqual(len(data3), 1) + if transform is None: + # Check without providing transform + dataset2 = CacheDataset(data=test_data, cache_rate=0.5, as_contiguous=True) + for k in ["image", "label", "extra"]: + self.assertEqual(dataset[0][k], dataset2[0][k]) + if transform is None: self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz")) self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz")) diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index e7d51be63a..6eca6113f0 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -56,6 +56,17 @@ def test_shape(self, replace_rate, num_replace_workers, transform): num_init_workers=4, num_replace_workers=num_replace_workers, ) + if transform is None: + # Check without providing transform + dataset2 = SmartCacheDataset( + data=test_data, + replace_rate=replace_rate, + cache_num=16, + num_init_workers=4, + num_replace_workers=num_replace_workers, + ) + for k in ["image", "label", "extra"]: + self.assertEqual(dataset[0][k], dataset2[0][k]) self.assertEqual(len(dataset._cache), dataset.cache_num) for i in range(dataset.cache_num): diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index 3655100dab..5d092c4ce5 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -84,7 +84,9 @@ 
TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW -TEST_CASE_ERROR_GRAY = [np.ones((16, 16, 2), dtype=np.uint8)] # wrong color channel +TEST_CASE_ERROR_0C = [np.ones((16, 16), dtype=np.uint8)] # no color channel +TEST_CASE_ERROR_1C = [np.ones((16, 16, 1), dtype=np.uint8)] # one color channel +TEST_CASE_ERROR_2C = [np.ones((16, 16, 2), dtype=np.uint8)] # two color channels TEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)] # 3D + color @@ -106,20 +108,6 @@ def save_rgba_tiff(array: np.ndarray, filename: str, mode: str): return filename -def save_gray_tiff(array: np.ndarray, filename: str): - """ - Save numpy array into a TIFF file - - Args: - array: numpy ndarray with any shape - filename: the filename to be used for the tiff file. - """ - img_gray = array - imwrite(filename, img_gray, shape=img_gray.shape, photometric="minisblack") - - return filename - - @skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!") def setUpModule(): # noqa: N802 hash_type = testing_data_config("images", FILE_KEY, "hash_type") @@ -187,13 +175,15 @@ def test_read_rgba(self, img_expected): self.assertIsNone(assert_array_equal(image["RGB"], img_expected)) self.assertIsNone(assert_array_equal(image["RGBA"], img_expected)) - @parameterized.expand([TEST_CASE_ERROR_GRAY, TEST_CASE_ERROR_3D]) + @parameterized.expand([TEST_CASE_ERROR_0C, TEST_CASE_ERROR_1C, TEST_CASE_ERROR_2C, TEST_CASE_ERROR_3D]) @skipUnless(has_tiff, "Requires tifffile.") def test_read_malformats(self, img_expected): + if self.backend == "cucim" and (len(img_expected.shape) < 3 or img_expected.shape[2] == 1): + # Until cuCIM addresses https://github.com/rapidsai/cucim/issues/230 + return reader = WSIReader(self.backend) - file_path = save_gray_tiff( - img_expected, os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff") - ) + file_path = os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff") + 
imwrite(file_path, img_expected, shape=img_expected.shape) with self.assertRaises((RuntimeError, ValueError, openslide.OpenSlideError if has_osl else ValueError)): with reader.read(file_path) as img_obj: reader.get_data(img_obj) diff --git a/tests/test_wsireader_new.py b/tests/test_wsireader_new.py index 63d61dfeb3..2ac4125f97 100644 --- a/tests/test_wsireader_new.py +++ b/tests/test_wsireader_new.py @@ -72,7 +72,9 @@ TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW -TEST_CASE_ERROR_GRAY = [np.ones((16, 16, 2), dtype=np.uint8)] # wrong color channel +TEST_CASE_ERROR_0C = [np.ones((16, 16), dtype=np.uint8)] # no color channel +TEST_CASE_ERROR_1C = [np.ones((16, 16, 1), dtype=np.uint8)] # one color channel +TEST_CASE_ERROR_2C = [np.ones((16, 16, 2), dtype=np.uint8)] # two color channels TEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)] # 3D + color @@ -103,7 +105,7 @@ def save_gray_tiff(array: np.ndarray, filename: str): filename: the filename to be used for the tiff file. 
""" img_gray = array - imwrite(filename, img_gray, shape=img_gray.shape, photometric="minisblack") + imwrite(filename, img_gray, shape=img_gray.shape) return filename @@ -180,13 +182,15 @@ def test_read_rgba(self, img_expected): self.assertIsNone(assert_array_equal(image["RGB"], img_expected)) self.assertIsNone(assert_array_equal(image["RGBA"], img_expected)) - @parameterized.expand([TEST_CASE_ERROR_GRAY, TEST_CASE_ERROR_3D]) + @parameterized.expand([TEST_CASE_ERROR_0C, TEST_CASE_ERROR_1C, TEST_CASE_ERROR_2C, TEST_CASE_ERROR_3D]) @skipUnless(has_tiff, "Requires tifffile.") def test_read_malformats(self, img_expected): + if self.backend == "cucim" and (len(img_expected.shape) < 3 or img_expected.shape[2] == 1): + # Until cuCIM addresses https://github.com/rapidsai/cucim/issues/230 + return reader = WSIReader(self.backend) - file_path = save_gray_tiff( - img_expected, os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff") - ) + file_path = os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff") + imwrite(file_path, img_expected, shape=img_expected.shape) with self.assertRaises((RuntimeError, ValueError, openslide.OpenSlideError if has_osl else ValueError)): with reader.read(file_path) as img_obj: reader.get_data(img_obj)