diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index ddf6d9e563..438b195ce7 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -485,16 +485,18 @@ class NumpyReader(ImageReader): Args: npz_keys: if loading npz file, only load the specified keys, if None, load all the items. stack the loaded items together to construct a new first dimension. + channel_dim: if not None, explicitly specify the channel dim, otherwise, treat the array as no channel. kwargs: additional args for `numpy.load` API except `allow_pickle`. more details about available args: https://numpy.org/doc/stable/reference/generated/numpy.load.html """ - def __init__(self, npz_keys: Optional[KeysCollection] = None, **kwargs): + def __init__(self, npz_keys: Optional[KeysCollection] = None, channel_dim: Optional[int] = None, **kwargs): super().__init__() if npz_keys is not None: npz_keys = ensure_tuple(npz_keys) self.npz_keys = npz_keys + self.channel_dim = channel_dim self.kwargs = kwargs def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: @@ -558,9 +560,13 @@ def get_data(self, img): for i in ensure_tuple(img): header = {} if isinstance(i, np.ndarray): - # can not detect the channel dim of numpy array, use all the dims as spatial_shape - header["spatial_shape"] = i.shape + # if `channel_dim` is None, can not detect the channel dim, use all the dims as spatial_shape + spatial_shape = np.asarray(i.shape) + if isinstance(self.channel_dim, int): + spatial_shape = np.delete(spatial_shape, self.channel_dim) + header["spatial_shape"] = spatial_shape img_array.append(i) + header["original_channel_dim"] = self.channel_dim if isinstance(self.channel_dim, int) else "no_channel" _copy_compatible_dict(header, compatible_meta) return _stack_images(img_array, compatible_meta), compatible_meta @@ -753,7 +759,7 @@ def get_data( region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype) metadata: Dict = {} - metadata["spatial_shape"] = region.shape[:-1] + metadata["spatial_shape"] = np.asarray(region.shape[:-1]) metadata["original_channel_dim"] = -1 region = EnsureChannelFirst()(region, metadata) if patch_size is None: diff --git a/tests/test_numpy_reader.py b/tests/test_numpy_reader.py index a57a036905..d84f339e3d 100644 --- a/tests/test_numpy_reader.py +++ b/tests/test_numpy_reader.py @@ -10,12 +10,15 @@ # limitations under the License. import os +import sys import tempfile import unittest import numpy as np +import torch -from monai.data import NumpyReader +from monai.data import DataLoader, Dataset, NumpyReader +from monai.transforms import LoadImaged class TestNumpyReader(unittest.TestCase): @@ -27,8 +30,8 @@ def test_npy(self): reader = NumpyReader() result = reader.get_data(reader.read(filepath)) - self.assertTupleEqual(result[1]["spatial_shape"], test_data.shape) - self.assertTupleEqual(result[0].shape, test_data.shape) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data.shape) + np.testing.assert_allclose(result[0].shape, test_data.shape) np.testing.assert_allclose(result[0], test_data) def test_npz1(self): @@ -39,8 +42,8 @@ def test_npz1(self): reader = NumpyReader() result = reader.get_data(reader.read(filepath)) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, test_data1.shape) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data1.shape) + np.testing.assert_allclose(result[0].shape, test_data1.shape) np.testing.assert_allclose(result[0], test_data1) def test_npz2(self): @@ -52,8 +55,8 @@ def test_npz2(self): reader = NumpyReader() result = reader.get_data(reader.read(filepath)) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, (2, 3, 4, 4)) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data1.shape) + np.testing.assert_allclose(result[0].shape, (2, 3, 4, 4)) np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2])) def test_npz3(self): @@ -65,8 +68,8 @@ def test_npz3(self): reader = NumpyReader(npz_keys=["test1", "test2"]) result = reader.get_data(reader.read(filepath)) - self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) - self.assertTupleEqual(result[0].shape, (2, 3, 4, 4)) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data1.shape) + np.testing.assert_allclose(result[0].shape, (2, 3, 4, 4)) np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2])) def test_npy_pickle(self): @@ -77,7 +80,7 @@ def test_npy_pickle(self): reader = NumpyReader() result = reader.get_data(reader.read(filepath))[0].item() - self.assertTupleEqual(result["test"].shape, test_data["test"].shape) + np.testing.assert_allclose(result["test"].shape, test_data["test"].shape) np.testing.assert_allclose(result["test"], test_data["test"]) def test_kwargs(self): @@ -88,7 +91,39 @@ def test_kwargs(self): reader = NumpyReader(mmap_mode="r") result = reader.get_data(reader.read(filepath, mmap_mode=None))[0].item() - self.assertTupleEqual(result["test"].shape, test_data["test"].shape) + np.testing.assert_allclose(result["test"].shape, test_data["test"].shape) + + def test_dataloader(self): + test_data = np.random.randint(0, 256, size=[3, 4, 5]) + datalist = [] + with tempfile.TemporaryDirectory() as tempdir: + for i in range(4): + filepath = os.path.join(tempdir, f"test_data{i}.npz") + np.savez(filepath, test_data) + datalist.append({"image": filepath}) + + num_workers = 2 if sys.platform == "linux" else 0 + loader = DataLoader( + Dataset(data=datalist, transform=LoadImaged(keys="image", reader=NumpyReader())), + batch_size=2, + num_workers=num_workers, + ) + for d in loader: + for s in d["image_meta_dict"]["spatial_shape"]: + torch.testing.assert_allclose(s, torch.as_tensor([3, 4, 5])) + for c in d["image"]: + torch.testing.assert_allclose(c, test_data) + + def test_channel_dim(self): + test_data = np.random.randint(0, 256, size=[3, 4, 5, 2]) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npy") + np.save(filepath, test_data) + + reader = NumpyReader(channel_dim=-1) + result = reader.get_data(reader.read(filepath)) + np.testing.assert_allclose(result[1]["spatial_shape"], test_data.shape[:-1]) + self.assertEqual(result[1]["original_channel_dim"], -1) if __name__ == "__main__": diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index 7cd9efbf06..b86c84cce5 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -14,6 +14,7 @@ from unittest import skipUnless import numpy as np +import torch from numpy.testing import assert_array_equal from parameterized import parameterized @@ -151,8 +152,8 @@ def test_with_dataloader(self, file_path, level, expected_spatial_shape, expecte dataset = Dataset([{"image": file_path}], transform=train_transform) data_loader = DataLoader(dataset) data: dict = first(data_loader) - spatial_shape = tuple(d.item() for d in data["image_meta_dict"]["spatial_shape"]) - self.assertTupleEqual(spatial_shape, expected_spatial_shape) + for s in data["image_meta_dict"]["spatial_shape"]: + torch.testing.assert_allclose(s, expected_spatial_shape) self.assertTupleEqual(data["image"].shape, expected_shape)