Skip to content
27 changes: 7 additions & 20 deletions monai/data/nifti_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Sequence, Union
from typing import Dict, Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -93,12 +93,7 @@ def __init__(
self.squeeze_end_dims = squeeze_end_dims
self.data_root_dir = data_root_dir

def save(
self,
data: Union[torch.Tensor, np.ndarray],
meta_data: Optional[Dict] = None,
patch_index: Optional[int] = None,
) -> None:
def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
"""
Save data into a Nifti file.
The meta_data could optionally have the following keys:
Expand All @@ -107,6 +102,7 @@ def save(
- ``'original_affine'`` -- for data orientation handling, defaulting to an identity matrix.
- ``'affine'`` -- for data output affine, defaulting to an identity matrix.
- ``'spatial_shape'`` -- for data output shape.
- ``'patch_index'`` -- if the data is a patch of big image, append the patch index to filename.

When meta_data is specified, the saver will try to resample batch data from the space
defined by "affine" to the space defined by "original_affine".
Expand All @@ -117,7 +113,6 @@ def save(
data: target data content that to be saved as a NIfTI format file.
Assuming the data shape starts with a channel dimension and followed by spatial dimensions.
meta_data: the meta data information corresponding to the data.
patch_index: if the data is a patch of big image, need to append the patch index to filename.

See Also
:py:meth:`monai.data.nifti_writer.write_nifti`
Expand All @@ -127,6 +122,7 @@ def save(
original_affine = meta_data.get("original_affine", None) if meta_data else None
affine = meta_data.get("affine", None) if meta_data else None
spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None
patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None

if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()
Expand Down Expand Up @@ -158,12 +154,7 @@ def save(
output_dtype=self.output_dtype,
)

def save_batch(
self,
batch_data: Union[torch.Tensor, np.ndarray],
meta_data: Optional[Dict] = None,
patch_indice: Optional[Sequence[int]] = None,
) -> None:
def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
"""
Save a batch of data into Nifti format files.

Expand All @@ -180,11 +171,7 @@ def save_batch(
Args:
batch_data: target batch data content that save into NIfTI format.
meta_data: every key-value in the meta_data is corresponding to a batch of data.
patch_indice: if the data is a patch of big image, need to append the patch index to filename.

"""
for i, data in enumerate(batch_data): # save a batch of files
self.save(
data=data,
meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None,
patch_index=patch_indice[i] if patch_indice is not None else None,
)
self.save(data=data, meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None)
27 changes: 7 additions & 20 deletions monai/data/png_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Sequence, Union
from typing import Dict, Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -71,18 +71,14 @@ def __init__(

self._data_index = 0

def save(
self,
data: Union[torch.Tensor, np.ndarray],
meta_data: Optional[Dict] = None,
patch_index: Optional[int] = None,
) -> None:
def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
"""
Save data into a png file.
The meta_data could optionally have the following keys:

- ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object.
- ``'spatial_shape'`` -- for data output shape.
- ``'patch_index'`` -- if the data is a patch of big image, append the patch index to filename.

If meta_data is None, use the default index (starting from 0) as the filename.

Expand All @@ -92,7 +88,6 @@ def save(
Shape of the spatial dimensions (C,H,W).
C should be 1, 3 or 4
meta_data: the meta data information corresponding to the data.
patch_index: if the data is a patch of big image, need to append the patch index to filename.

Raises:
ValueError: When ``data`` channels is not one of [1, 3, 4].
Expand All @@ -104,6 +99,7 @@ def save(
filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
self._data_index += 1
spatial_shape = meta_data.get("spatial_shape", None) if meta_data and self.resample else None
patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None

if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()
Expand All @@ -126,22 +122,13 @@ def save(
scale=self.scale,
)

def save_batch(
self,
batch_data: Union[torch.Tensor, np.ndarray],
meta_data: Optional[Dict] = None,
patch_indice: Optional[Sequence[int]] = None,
) -> None:
def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
"""Save a batch of data into png format files.

Args:
batch_data: target batch data content that save into png format.
meta_data: every key-value in the meta_data is corresponding to a batch of data.
patch_indice: if the data is a patch of big image, need to append the patch index to filename.

"""
for i, data in enumerate(batch_data): # save a batch of files
self.save(
data=data,
meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None,
patch_index=patch_indice[i] if patch_indice is not None else None,
)
self.save(data=data, meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None)
7 changes: 2 additions & 5 deletions monai/handlers/segmentation_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@

from monai.config import DtypeLike
from monai.transforms import SaveImage
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode, exact_version, optional_import
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
if TYPE_CHECKING:
Expand Down Expand Up @@ -145,6 +143,5 @@ def __call__(self, engine: Engine) -> None:
"""
meta_data = self.batch_transform(engine.state.batch)
engine_output = self.output_transform(engine.state.output)
patch_indice = engine.state.batch.get(Key.PATCH_INDEX, None)
self._saver(engine_output, meta_data, patch_indice)
self._saver(engine_output, meta_data)
self.logger.info("saved all the model outputs into files.")
32 changes: 27 additions & 5 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ class RandSpatialCropSamplesd(RandomizableTransform, MapTransform):
Crop image with random size or specific size ROI to generate a list of N samples.
It can crop at a random position as center or at the image center. And allows to set
the minimum size to limit the randomly generated ROI. Suppose all the expected fields
specified by `keys` have same shape.
specified by `keys` have same shape, and add `patch_index` to the corresponding meta data.
It will return a list of dictionaries for all the cropped images.

Args:
Expand All @@ -495,6 +495,9 @@ class RandSpatialCropSamplesd(RandomizableTransform, MapTransform):
random_center: crop at random position as center or the image center.
random_size: crop with random size or specific size ROI.
The actual size is sampled from `randint(roi_size, img_size)`.
meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data,
default is `meta_dict`, the meta data is a dictionary object.
used to add `patch_index` to the meta dict.
allow_missing_keys: don't raise exception if key is missing.

Raises:
Expand All @@ -509,6 +512,7 @@ def __init__(
num_samples: int,
random_center: bool = True,
random_size: bool = True,
meta_key_postfix: str = "meta_dict",
allow_missing_keys: bool = False,
) -> None:
RandomizableTransform.__init__(self, prob=1.0, do_transform=True)
Expand All @@ -517,6 +521,7 @@ def __init__(
raise ValueError(f"num_samples must be positive, got {num_samples}.")
self.num_samples = num_samples
self.cropper = RandSpatialCropd(keys, roi_size, random_center, random_size, allow_missing_keys)
self.meta_key_postfix = meta_key_postfix

def set_random_state(
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
Expand All @@ -530,9 +535,15 @@ def randomize(self, data: Optional[Any] = None) -> None:

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
ret = []
d = dict(data)
for i in range(self.num_samples):
cropped = self.cropper(data)
cropped[Key.PATCH_INDEX] = i # type: ignore
cropped = self.cropper(d)
# add `patch_index` to the meta data
for key in self.key_iterator(d):
meta_data_key = f"{key}_{self.meta_key_postfix}"
if meta_data_key not in cropped:
cropped[meta_data_key] = {} # type: ignore
cropped[meta_data_key][Key.PATCH_INDEX] = i
ret.append(cropped)
return ret

Expand Down Expand Up @@ -687,6 +698,8 @@ class RandCropByPosNegLabeld(RandomizableTransform, MapTransform):
Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`.
Crop random fixed sized regions with the center being a foreground or background voxel
based on the Pos Neg Ratio.
Suppose all the expected fields specified by `keys` have same shape,
and add `patch_index` to the corresponding meta data.
And will return a list of dictionaries for all the cropped images.

Args:
Expand All @@ -712,6 +725,9 @@ class RandCropByPosNegLabeld(RandomizableTransform, MapTransform):
`image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices_key`
and `bg_indices_key` together, expect to be 1 dim array of spatial indices after flattening.
a typical usage is to call `FgBgToIndicesd` transform first and cache the results.
meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data,
default is `meta_dict`, the meta data is a dictionary object.
used to add `patch_index` to the meta dict.
allow_missing_keys: don't raise exception if key is missing.

Raises:
Expand All @@ -732,6 +748,7 @@ def __init__(
image_threshold: float = 0.0,
fg_indices_key: Optional[str] = None,
bg_indices_key: Optional[str] = None,
meta_key_postfix: str = "meta_dict",
allow_missing_keys: bool = False,
) -> None:
RandomizableTransform.__init__(self)
Expand All @@ -748,6 +765,7 @@ def __init__(
self.image_threshold = image_threshold
self.fg_indices_key = fg_indices_key
self.bg_indices_key = bg_indices_key
self.meta_key_postfix = meta_key_postfix
self.centers: Optional[List[List[np.ndarray]]] = None

def randomize(
Expand Down Expand Up @@ -789,8 +807,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
# fill in the extra keys with unmodified data
for key in set(data.keys()).difference(set(self.keys)):
results[i][key] = data[key]
# add patch index in the meta data
results[i][Key.PATCH_INDEX] = i # type: ignore
# add `patch_index` to the meta data
for key in self.key_iterator(d):
meta_data_key = f"{key}_{self.meta_key_postfix}"
if meta_data_key not in results[i]:
results[i][meta_data_key] = {} # type: ignore
results[i][meta_data_key][Key.PATCH_INDEX] = i

return results

Expand Down
13 changes: 4 additions & 9 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,19 +269,14 @@ def __init__(

self.save_batch = save_batch

def __call__(
self,
img: Union[torch.Tensor, np.ndarray],
meta_data: Optional[Dict] = None,
patch_index=None, # type is Union[Sequence[int], int, None], can't be compatible with save and save_batch
):
def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None):
"""
Args:
img: target data content that save into file.
meta_data: key-value pairs of meta_data corresponding to the data.
patch_index: if the data is a patch of big image, need to append the patch index to filename.

"""
if self.save_batch:
self.saver.save_batch(img, meta_data, patch_index)
self.saver.save_batch(img, meta_data)
else:
self.saver.save(img, meta_data, patch_index)
self.saver.save(img, meta_data)
6 changes: 2 additions & 4 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
from monai.data.image_reader import ImageReader
from monai.transforms.io.array import LoadImage, SaveImage
from monai.transforms.transform import MapTransform
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode

__all__ = [
"LoadImaged",
Expand Down Expand Up @@ -229,7 +227,7 @@ def __call__(self, data):
d = dict(data)
for key in self.key_iterator(d):
meta_data = d[f"{key}_{self.meta_key_postfix}"] if self.meta_key_postfix is not None else None
self._saver(img=d[key], meta_data=meta_data, patch_index=d.get(Key.PATCH_INDEX, None))
self._saver(img=d[key], meta_data=meta_data)
return d


Expand Down
2 changes: 1 addition & 1 deletion tests/test_rand_crop_by_pos_neg_labeld.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_type_shape(self, input_param, input_data, expected_type, expected_shape
self.assertTupleEqual(result[0]["extral"].shape, expected_shape)
self.assertTupleEqual(result[0]["label"].shape, expected_shape)
for i, item in enumerate(result):
self.assertEqual(item["patch_index"], i)
self.assertEqual(item["image_meta_dict"]["patch_index"], i)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion tests/test_rand_spatial_crop_samplesd.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def test_shape(self, input_param, input_data, expected_shape, expected_last):
self.assertTupleEqual(item["img"].shape, expected)
self.assertTupleEqual(item["seg"].shape, expected)
for i, item in enumerate(result):
self.assertEqual(item["patch_index"], i)
self.assertEqual(item["img_meta_dict"]["patch_index"], i)
self.assertEqual(item["seg_meta_dict"]["patch_index"], i)
np.testing.assert_allclose(item["img"], expected_last["img"])
np.testing.assert_allclose(item["seg"], expected_last["seg"])

Expand Down
2 changes: 1 addition & 1 deletion tests/test_save_imaged.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_saved_content(self, test_data, output_ext, resample, save_batch):
filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_trans" + output_ext)
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))
else:
patch_index = test_data.get("patch_index", None)
patch_index = test_data["img_meta_dict"].get("patch_index", None)
patch_index = f"_{patch_index}" if patch_index is not None else ""
filepath = os.path.join("testfile0", "testfile0" + "_trans" + patch_index + output_ext)
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))
Expand Down