Merged
30 changes: 23 additions & 7 deletions monai/data/nifti_saver.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Union
from typing import Dict, Optional, Sequence, Union

import numpy as np
import torch
@@ -93,7 +93,12 @@ def __init__(
self.squeeze_end_dims = squeeze_end_dims
self.data_root_dir = data_root_dir

def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
def save(
self,
data: Union[torch.Tensor, np.ndarray],
meta_data: Optional[Dict] = None,
patch_index: Optional[int] = None,
) -> None:
"""
Save data into a Nifti file.
The meta_data could optionally have the following keys:
@@ -112,6 +117,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
data: target data content to be saved as a NIfTI format file.
Assuming the data shape starts with a channel dimension followed by spatial dimensions.
meta_data: the meta data information corresponding to the data.
patch_index: if the data is a patch of a big image, the patch index will be appended to the filename.

See Also
:py:meth:`monai.data.nifti_writer.write_nifti`
@@ -125,8 +131,8 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()

filename = create_file_basename(self.output_postfix, filename, self.output_dir, self.data_root_dir)
filename = f"{filename}{self.output_ext}"
path = create_file_basename(self.output_postfix, filename, self.output_dir, self.data_root_dir, patch_index)
path = f"{path}{self.output_ext}"
# change data shape to be (channel, h, w, d)
while len(data.shape) < 4:
data = np.expand_dims(data, -1)
@@ -140,7 +146,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]

write_nifti(
data,
file_name=filename,
file_name=path,
affine=affine,
target_affine=original_affine,
resample=self.resample,
@@ -152,7 +158,12 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
output_dtype=self.output_dtype,
)

def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
def save_batch(
self,
batch_data: Union[torch.Tensor, np.ndarray],
meta_data: Optional[Dict] = None,
patch_indice: Optional[Sequence[int]] = None,
) -> None:
"""
Save a batch of data into Nifti format files.

@@ -169,6 +180,11 @@ def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Opt
Args:
batch_data: target batch data content to be saved in NIfTI format.
meta_data: every key-value pair in meta_data corresponds to a batch of data.
patch_indice: if the data contains patches of a big image, the patch indices will be appended to the filenames.
"""
for i, data in enumerate(batch_data): # save a batch of files
self.save(data, {k: meta_data[k][i] for k in meta_data} if meta_data else None)
self.save(
data=data,
meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None,
patch_index=patch_indice[i] if patch_indice is not None else None,
)
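
For illustration, a minimal usage sketch of the extended NiftiSaver.save; the paths and filenames are hypothetical, and the expected output name follows from the create_file_basename change below:

import numpy as np

from monai.data import NiftiSaver

saver = NiftiSaver(output_dir="./out", output_postfix="seg", resample=False)
img = np.zeros((1, 8, 8, 8))  # channel-first volume
meta = {"filename_or_obj": "case1.nii.gz"}
# patch 3 of a larger image; expected output: ./out/case1/case1_seg_3.nii.gz
saver.save(img, meta_data=meta, patch_index=3)
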
30 changes: 23 additions & 7 deletions monai/data/png_saver.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Union
from typing import Dict, Optional, Sequence, Union

import numpy as np
import torch
@@ -71,7 +71,12 @@ def __init__(

self._data_index = 0

def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
def save(
self,
data: Union[torch.Tensor, np.ndarray],
meta_data: Optional[Dict] = None,
patch_index: Optional[int] = None,
) -> None:
"""
Save data into a png file.
The meta_data could optionally have the following keys:
@@ -87,6 +92,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
Shape of the spatial dimensions (C,H,W).
C should be 1, 3 or 4
meta_data: the meta data information corresponding to the data.
patch_index: if the data is a patch of a big image, the patch index will be appended to the filename.

Raises:
ValueError: When ``data`` channels is not one of [1, 3, 4].
@@ -102,8 +108,8 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()

filename = create_file_basename(self.output_postfix, filename, self.output_dir, self.data_root_dir)
filename = f"{filename}{self.output_ext}"
path = create_file_basename(self.output_postfix, filename, self.output_dir, self.data_root_dir, patch_index)
path = f"{path}{self.output_ext}"

if data.shape[0] == 1:
data = data.squeeze(0)
@@ -114,18 +120,28 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]

write_png(
np.asarray(data),
file_name=filename,
file_name=path,
output_spatial_shape=spatial_shape,
mode=self.mode,
scale=self.scale,
)

def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
def save_batch(
self,
batch_data: Union[torch.Tensor, np.ndarray],
meta_data: Optional[Dict] = None,
patch_indice: Optional[Sequence[int]] = None,
) -> None:
"""Save a batch of data into png format files.

Args:
batch_data: target batch data content to be saved in png format.
meta_data: every key-value pair in meta_data corresponds to a batch of data.
patch_indice: if the data contains patches of a big image, the patch indices will be appended to the filenames.
"""
for i, data in enumerate(batch_data): # save a batch of files
self.save(data, {k: meta_data[k][i] for k in meta_data} if meta_data else None)
self.save(
data=data,
meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None,
patch_index=patch_indice[i] if patch_indice is not None else None,
)
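
A corresponding sketch for the batch path, assuming two single-channel images (all names hypothetical); each filename gets its patch index appended before the extension:

import torch

from monai.data import PNGSaver

saver = PNGSaver(output_dir="./out", output_postfix="seg", scale=255)
batch = torch.randint(0, 255, (2, 1, 16, 16))  # batch of two single-channel images
meta = {"filename_or_obj": ["a.png", "b.png"]}
# expected outputs: ./out/a/a_seg_0.png and ./out/b/b_seg_1.png
saver.save_batch(batch, meta_data=meta, patch_indice=[0, 1])
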
8 changes: 7 additions & 1 deletion monai/data/utils.py
@@ -600,6 +600,7 @@ def create_file_basename(
input_file_name: str,
folder_path: str,
data_root_dir: str = "",
patch_index: Optional[int] = None,
) -> str:
"""
Utility function to create the path to the output file based on the input
@@ -623,6 +624,7 @@
absolute path. This is used to compute `input_file_rel_path`, the relative path to the file from
`data_root_dir` to preserve folder structure when saving in case there are files in different
folders with the same file names.
patch_index: if not None, append the patch index to the filename.
"""

# get the filename and directory
@@ -641,11 +643,15 @@
if not os.path.exists(subfolder_path):
os.makedirs(subfolder_path)

if postfix:
if len(postfix) > 0:
# add the sub-folder plus the postfix name to become the file basename in the output path
output = os.path.join(subfolder_path, filename + "_" + postfix)
else:
output = os.path.join(subfolder_path, filename)

if patch_index is not None:
output += f"_{patch_index}"

return os.path.abspath(output)
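
A sketch of the resulting basenames, matching the new tests below (input paths made up; note the function also creates the output sub-folder as a side effect):

from monai.data.utils import create_file_basename

# postfix only, unchanged behaviour: /out/test/test_post
create_file_basename("post", "/data/test.nii.gz", "/out")
# patch index appended after the postfix: /out/test/test_post_8
create_file_basename("post", "/data/test.nii.gz", "/out", "", 8)
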


7 changes: 5 additions & 2 deletions monai/handlers/segmentation_saver.py
@@ -16,7 +16,9 @@

from monai.config import DtypeLike
from monai.transforms import SaveImage
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode, exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
if TYPE_CHECKING:
@@ -143,5 +145,6 @@ def __call__(self, engine: Engine) -> None:
"""
meta_data = self.batch_transform(engine.state.batch)
engine_output = self.output_transform(engine.state.output)
self._saver(engine_output, meta_data)
patch_indice = engine.state.batch.get(Key.PATCH_INDEX, None)
self._saver(engine_output, meta_data, patch_indice)
self.logger.info("saved all the model outputs into files.")
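
The handler reads the indices straight from the batch dictionary; a minimal sketch of that lookup (field values illustrative):

from monai.utils import ImageMetaKey as Key

batch = {
    "image_meta_dict": {"filename_or_obj": ["case1.nii.gz"]},
    Key.PATCH_INDEX: [4],  # i.e. batch["patch_index"], one index per batch item
}
# mirrors the handler: a missing key falls back to None, preserving the old behaviour
patch_indice = batch.get(Key.PATCH_INDEX, None)
assert patch_indice == [4]
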
10 changes: 9 additions & 1 deletion monai/transforms/croppad/dictionary.py
@@ -41,6 +41,7 @@
map_binary_to_indices,
weighted_patch_samples,
)
from monai.utils import ImageMetaKey as Key
from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
from monai.utils.enums import InverseKeys

@@ -528,7 +529,12 @@ def randomize(self, data: Optional[Any] = None) -> None:
pass

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
return [self.cropper(data) for _ in range(self.num_samples)]
ret = []
for i in range(self.num_samples):
cropped = self.cropper(data)
cropped[Key.PATCH_INDEX] = i # type: ignore
ret.append(cropped)
return ret


class CropForegroundd(MapTransform, InvertibleTransform):
@@ -783,6 +789,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
# fill in the extra keys with unmodified data
for key in set(data.keys()).difference(set(self.keys)):
results[i][key] = data[key]
# add patch index in the meta data
results[i][Key.PATCH_INDEX] = i # type: ignore

return results
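
A brief sketch of the effect on the cropping transforms above, using RandSpatialCropSamplesd (array shape chosen arbitrarily):

import numpy as np

from monai.transforms import RandSpatialCropSamplesd

cropper = RandSpatialCropSamplesd(keys="img", roi_size=(4, 4, 4), num_samples=3, random_size=False)
samples = cropper({"img": np.random.rand(1, 8, 8, 8)})
# each returned sample now carries its patch index
print([s["patch_index"] for s in samples])  # [0, 1, 2]
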

17 changes: 14 additions & 3 deletions monai/transforms/io/array.py
@@ -269,8 +269,19 @@ def __init__(

self.save_batch = save_batch

def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None):
def __call__(
self,
img: Union[torch.Tensor, np.ndarray],
meta_data: Optional[Dict] = None,
patch_index=None,  # Union[Sequence[int], int, None]; left unannotated because save and save_batch expect different types
):
"""
Args:
img: target data content to be saved into file.
meta_data: key-value pairs of meta data corresponding to the data.
patch_index: if the data is a patch of a big image, the patch index will be appended to the filename.
"""
if self.save_batch:
self.saver.save_batch(img, meta_data)
self.saver.save_batch(img, meta_data, patch_index)
else:
self.saver.save(img, meta_data)
self.saver.save(img, meta_data, patch_index)
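
A usage sketch for the array transform with a single patch; with save_batch=True the same argument would take a sequence of indices (names hypothetical):

import numpy as np

from monai.transforms import SaveImage

save = SaveImage(output_dir="./out", output_postfix="trans", output_ext=".nii.gz", resample=False)
img = np.zeros((1, 4, 4, 4))
meta = {"filename_or_obj": "case2.nii.gz"}
save(img, meta_data=meta, patch_index=7)  # expected output: ./out/case2/case2_trans_7.nii.gz
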
10 changes: 7 additions & 3 deletions monai/transforms/io/dictionary.py
@@ -23,7 +23,9 @@
from monai.data.image_reader import ImageReader
from monai.transforms.io.array import LoadImage, SaveImage
from monai.transforms.transform import MapTransform
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode

__all__ = [
"LoadImaged",
@@ -124,7 +126,9 @@ class SaveImaged(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.SaveImage`.

NB: image should include channel dimension: [B],C,H,W,[D].
Note:
Image should include channel dimension: [B],C,H,W,[D].
If the data is a patch of a big image, the patch index will be appended to the filename.

Args:
keys: keys of the corresponding items to be transformed.
@@ -225,7 +229,7 @@ def __call__(self, data):
d = dict(data)
for key in self.key_iterator(d):
meta_data = d[f"{key}_{self.meta_key_postfix}"] if self.meta_key_postfix is not None else None
self._saver(img=d[key], meta_data=meta_data)
self._saver(img=d[key], meta_data=meta_data, patch_index=d.get(Key.PATCH_INDEX, None))
return d
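
And the dictionary version end to end; the "patch_index" entry is what the cropping transforms above now write, matching TEST_CASE_6 below (values hypothetical):

import torch

from monai.transforms import SaveImaged

trans = SaveImaged(keys="img", output_dir="./out", resample=False)
d = {
    "img": torch.zeros(1, 4, 4, 4),
    "img_meta_dict": {"filename_or_obj": "case3.nii.gz"},
    "patch_index": 2,  # picked up via ImageMetaKey.PATCH_INDEX
}
trans(d)  # expected output: ./out/case3/case3_trans_2.nii.gz
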


1 change: 1 addition & 0 deletions monai/utils/misc.py
@@ -358,3 +358,4 @@ class ImageMetaKey:
"""

FILENAME_OR_OBJ = "filename_or_obj"
PATCH_INDEX = "patch_index"
8 changes: 8 additions & 0 deletions tests/test_file_basename.py
@@ -57,10 +57,18 @@ def test_value(self):
expected = os.path.join(output_tmp, "test", "test")
self.assertEqual(result, expected)

result = create_file_basename("", "test.txt", output_tmp, "foo", 5)
expected = os.path.join(output_tmp, "test", "test_5")
self.assertEqual(result, expected)

result = create_file_basename("post", "test.tar.gz", output_tmp, "foo")
expected = os.path.join(output_tmp, "test", "test_post")
self.assertEqual(result, expected)

result = create_file_basename("post", "test.tar.gz", output_tmp, "foo", 8)
expected = os.path.join(output_tmp, "test", "test_post_8")
self.assertEqual(result, expected)


if __name__ == "__main__":
unittest.main()
9 changes: 7 additions & 2 deletions tests/test_handler_segmentation_saver.py
@@ -40,10 +40,15 @@ def _train_func(engine, batch):
saver = SegmentationSaver(output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255)
saver.attach(engine)

data = [{"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}]
data = [
{
"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)],
"patch_index": list(range(8)),
}
]
engine.run(data, max_epochs=1)
for i in range(8):
filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg" + output_ext)
filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg" + f"_{i}" + output_ext)
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))

@parameterized.expand([TEST_CASE_0, TEST_CASE_1])
2 changes: 2 additions & 0 deletions tests/test_rand_crop_by_pos_neg_labeld.py
@@ -91,6 +91,8 @@ def test_type_shape(self, input_param, input_data, expected_type, expected_shape
self.assertTupleEqual(result[0]["image"].shape, expected_shape)
self.assertTupleEqual(result[0]["extral"].shape, expected_shape)
self.assertTupleEqual(result[0]["label"].shape, expected_shape)
for i, item in enumerate(result):
self.assertEqual(item["patch_index"], i)


if __name__ == "__main__":
2 changes: 2 additions & 0 deletions tests/test_rand_spatial_crop_samplesd.py
@@ -70,6 +70,8 @@ def test_shape(self, input_param, input_data, expected_shape, expected_last):
for item, expected in zip(result, expected_shape):
self.assertTupleEqual(item["img"].shape, expected)
self.assertTupleEqual(item["seg"].shape, expected)
for i, item in enumerate(result):
self.assertEqual(item["patch_index"], i)
np.testing.assert_allclose(item["img"], expected_last["img"])
np.testing.assert_allclose(item["seg"], expected_last["seg"])

17 changes: 15 additions & 2 deletions tests/test_save_imaged.py
@@ -87,9 +87,20 @@
False,
]

TEST_CASE_6 = [
{
"img": torch.randint(0, 255, (1, 2, 3, 4)),
"img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"},
"patch_index": 6,
},
".nii.gz",
False,
False,
]


class TestSaveImaged(unittest.TestCase):
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
def test_saved_content(self, test_data, output_ext, resample, save_batch):
with tempfile.TemporaryDirectory() as tempdir:
trans = SaveImaged(
@@ -106,7 +117,9 @@ def test_saved_content(self, test_data, output_ext, resample, save_batch):
filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_trans" + output_ext)
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))
else:
filepath = os.path.join("testfile0", "testfile0" + "_trans" + output_ext)
patch_index = test_data.get("patch_index", None)
patch_index = f"_{patch_index}" if patch_index is not None else ""
filepath = os.path.join("testfile0", "testfile0" + "_trans" + patch_index + output_ext)
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))


Expand Down