Skip to content
Merged
14 changes: 10 additions & 4 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,16 +485,18 @@ class NumpyReader(ImageReader):
Args:
npz_keys: if loading npz file, only load the specified keys, if None, load all the items.
stack the loaded items together to construct a new first dimension.
channel_dim: if not None, explicitly specify the channel dim, otherwise, treat the array as no channel.
kwargs: additional args for `numpy.load` API except `allow_pickle`. more details about available args:
https://numpy.org/doc/stable/reference/generated/numpy.load.html

"""

def __init__(self, npz_keys: Optional[KeysCollection] = None, **kwargs):
def __init__(self, npz_keys: Optional[KeysCollection] = None, channel_dim: Optional[int] = None, **kwargs):
super().__init__()
if npz_keys is not None:
npz_keys = ensure_tuple(npz_keys)
self.npz_keys = npz_keys
self.channel_dim = channel_dim
self.kwargs = kwargs

def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
Expand Down Expand Up @@ -558,9 +560,13 @@ def get_data(self, img):
for i in ensure_tuple(img):
header = {}
if isinstance(i, np.ndarray):
# can not detect the channel dim of numpy array, use all the dims as spatial_shape
header["spatial_shape"] = i.shape
# if `channel_dim` is None, can not detect the channel dim, use all the dims as spatial_shape
spatial_shape = np.asarray(i.shape)
if isinstance(self.channel_dim, int):
spatial_shape = np.delete(spatial_shape, self.channel_dim)
header["spatial_shape"] = spatial_shape
img_array.append(i)
header["original_channel_dim"] = self.channel_dim if isinstance(self.channel_dim, int) else "no_channel"
_copy_compatible_dict(header, compatible_meta)

return _stack_images(img_array, compatible_meta), compatible_meta
Expand Down Expand Up @@ -753,7 +759,7 @@ def get_data(
region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype)

metadata: Dict = {}
metadata["spatial_shape"] = region.shape[:-1]
metadata["spatial_shape"] = np.asarray(region.shape[:-1])
metadata["original_channel_dim"] = -1
region = EnsureChannelFirst()(region, metadata)
if patch_size is None:
Expand Down
57 changes: 46 additions & 11 deletions tests/test_numpy_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
# limitations under the License.

import os
import sys
import tempfile
import unittest

import numpy as np
import torch

from monai.data import NumpyReader
from monai.data import DataLoader, Dataset, NumpyReader
from monai.transforms import LoadImaged


class TestNumpyReader(unittest.TestCase):
Expand All @@ -27,8 +30,8 @@ def test_npy(self):

reader = NumpyReader()
result = reader.get_data(reader.read(filepath))
self.assertTupleEqual(result[1]["spatial_shape"], test_data.shape)
self.assertTupleEqual(result[0].shape, test_data.shape)
np.testing.assert_allclose(result[1]["spatial_shape"], test_data.shape)
np.testing.assert_allclose(result[0].shape, test_data.shape)
np.testing.assert_allclose(result[0], test_data)

def test_npz1(self):
Expand All @@ -39,8 +42,8 @@ def test_npz1(self):

reader = NumpyReader()
result = reader.get_data(reader.read(filepath))
self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape)
self.assertTupleEqual(result[0].shape, test_data1.shape)
np.testing.assert_allclose(result[1]["spatial_shape"], test_data1.shape)
np.testing.assert_allclose(result[0].shape, test_data1.shape)
np.testing.assert_allclose(result[0], test_data1)

def test_npz2(self):
Expand All @@ -52,8 +55,8 @@ def test_npz2(self):

reader = NumpyReader()
result = reader.get_data(reader.read(filepath))
self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape)
self.assertTupleEqual(result[0].shape, (2, 3, 4, 4))
np.testing.assert_allclose(result[1]["spatial_shape"], test_data1.shape)
np.testing.assert_allclose(result[0].shape, (2, 3, 4, 4))
np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2]))

def test_npz3(self):
Expand All @@ -65,8 +68,8 @@ def test_npz3(self):

reader = NumpyReader(npz_keys=["test1", "test2"])
result = reader.get_data(reader.read(filepath))
self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape)
self.assertTupleEqual(result[0].shape, (2, 3, 4, 4))
np.testing.assert_allclose(result[1]["spatial_shape"], test_data1.shape)
np.testing.assert_allclose(result[0].shape, (2, 3, 4, 4))
np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2]))

def test_npy_pickle(self):
Expand All @@ -77,7 +80,7 @@ def test_npy_pickle(self):

reader = NumpyReader()
result = reader.get_data(reader.read(filepath))[0].item()
self.assertTupleEqual(result["test"].shape, test_data["test"].shape)
np.testing.assert_allclose(result["test"].shape, test_data["test"].shape)
np.testing.assert_allclose(result["test"], test_data["test"])

def test_kwargs(self):
Expand All @@ -88,7 +91,39 @@ def test_kwargs(self):

reader = NumpyReader(mmap_mode="r")
result = reader.get_data(reader.read(filepath, mmap_mode=None))[0].item()
self.assertTupleEqual(result["test"].shape, test_data["test"].shape)
np.testing.assert_allclose(result["test"].shape, test_data["test"].shape)

def test_dataloader(self):
test_data = np.random.randint(0, 256, size=[3, 4, 5])
datalist = []
with tempfile.TemporaryDirectory() as tempdir:
for i in range(4):
filepath = os.path.join(tempdir, f"test_data{i}.npz")
np.savez(filepath, test_data)
datalist.append({"image": filepath})

num_workers = 2 if sys.platform == "linux" else 0
loader = DataLoader(
Dataset(data=datalist, transform=LoadImaged(keys="image", reader=NumpyReader())),
batch_size=2,
num_workers=num_workers,
)
for d in loader:
for s in d["image_meta_dict"]["spatial_shape"]:
torch.testing.assert_allclose(s, torch.as_tensor([3, 4, 5]))
for c in d["image"]:
torch.testing.assert_allclose(c, test_data)

def test_channel_dim(self):
test_data = np.random.randint(0, 256, size=[3, 4, 5, 2])
with tempfile.TemporaryDirectory() as tempdir:
filepath = os.path.join(tempdir, "test_data.npy")
np.save(filepath, test_data)

reader = NumpyReader(channel_dim=-1)
result = reader.get_data(reader.read(filepath))
np.testing.assert_allclose(result[1]["spatial_shape"], test_data.shape[:-1])
self.assertEqual(result[1]["original_channel_dim"], -1)


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions tests/test_wsireader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from unittest import skipUnless

import numpy as np
import torch
from numpy.testing import assert_array_equal
from parameterized import parameterized

Expand Down Expand Up @@ -151,8 +152,8 @@ def test_with_dataloader(self, file_path, level, expected_spatial_shape, expecte
dataset = Dataset([{"image": file_path}], transform=train_transform)
data_loader = DataLoader(dataset)
data: dict = first(data_loader)
spatial_shape = tuple(d.item() for d in data["image_meta_dict"]["spatial_shape"])
self.assertTupleEqual(spatial_shape, expected_spatial_shape)
for s in data["image_meta_dict"]["spatial_shape"]:
torch.testing.assert_allclose(s, expected_spatial_shape)
self.assertTupleEqual(data["image"].shape, expected_shape)


Expand Down