diff --git a/monailabel/interfaces/tasks/infer.py b/monailabel/interfaces/tasks/infer.py index b0f329fcb..25870a2c8 100644 --- a/monailabel/interfaces/tasks/infer.py +++ b/monailabel/interfaces/tasks/infer.py @@ -434,6 +434,8 @@ def writer(self, data, extension=None, dtype=None): data["result_extension"] = extension if dtype is not None: data["result_dtype"] = dtype + if self.labels is not None: + data["labels"] = self.labels writer = Writer(label=self.output_label_key, json=self.output_json_key) return writer(data) diff --git a/monailabel/transform/writer.py b/monailabel/transform/writer.py index c65c4bdf0..1d6728a8f 100644 --- a/monailabel/transform/writer.py +++ b/monailabel/transform/writer.py @@ -10,8 +10,10 @@ # limitations under the License. import logging import tempfile +from typing import Any, Dict, Iterable, List, Optional import itk +import nrrd import numpy as np from monai.data import write_nifti @@ -58,6 +60,85 @@ def write_itk(image_np, output_file, affine, dtype, compress): itk.imwrite(result_image, output_file, compress) +def write_seg_nrrd( + image_np: np.ndarray, + output_file: str, + dtype: type, + affine: np.ndarray, + labels: List[str], + color_map: Optional[Dict[str, List[float]]] = None, + index_order: str = "C", + space: str = "left-posterior-superior", +) -> None: + """Write multi-channel seg.nrrd file. + + Args: + image_np: Image as numpy ndarray + output_file: Output file path that the seg.nrrd file should be saved to + dtype: numpy type e.g. float32 + affine: Affine matrix + labels: Labels of image segment which will be written to the nrrd header + color_map: Mapping from segment_name(str) to it's color e.g. {'heart': [255/255, 244/255, 209/255]} + index_order: Either 'C' or 'F' (see nrrd.write() documentation) + + Raises: + ValueError: In case affine is not provided + ValueError: In case labels are not provided + """ + image_np = image_np.transpose().copy() + if dtype: + image_np = image_np.astype(dtype) + + if not isinstance(labels, Iterable): + raise ValueError("Labels have to be defined, e.g. as a list") + + header: Dict[str, Any] = {} + for i, segment_name in enumerate(labels): + header.update( + { + f"Segment{i}_ID": segment_name, + f"Segment{i}_Name": segment_name, + } + ) + if color_map is not None: + header[f"Segment{i}_Color"] = " ".join(list(map(str, color_map[segment_name]))) + + if affine is None: + raise ValueError("Affine matrix has to be defined") + + kinds = ["list", "domain", "domain", "domain"] + + convert_aff_mat = np.diag([-1, -1, 1, 1]) + affine = convert_aff_mat @ affine + + _origin_key = (slice(-1), -1) + origin = affine[_origin_key] + + space_directions = np.array( + [ + [np.nan, np.nan, np.nan], + affine[0, :3], + affine[1, :3], + affine[2, :3], + ] + ) + + header.update( + { + "kinds": kinds, + "space directions": space_directions, + "space origin": origin, + "space": space, + } + ) + nrrd.write( + output_file, + image_np, + header=header, + index_order=index_order, + ) + + class Writer: def __init__( self, @@ -104,8 +185,18 @@ def __call__(self, data): output_file = tempfile.NamedTemporaryFile(suffix=ext).name logger.debug(f"Saving Image to: {output_file}") + if self.is_multichannel_image(image_np): + if ext != ".seg.nrrd": + logger.warning( + f"Using extension '{ext}' with multi-channel 4D label will probably fail" + + "Consider to use extension '.seg.nrrd'" + ) + labels = data.get("labels") + color_map = data.get("color_map") + logger.debug("Using write_seg_nrrd...") + write_seg_nrrd(image_np, output_file, dtype, affine, labels, color_map) # Issue with slicer:: https://discourse.itk.org/t/saving-non-orthogonal-volume-in-nifti-format/2760/22 - if self.nibabel and ext.lower() in [".nii", ".nii.gz"]: + elif self.nibabel and ext.lower() in [".nii", ".nii.gz"]: logger.debug("Using MONAI write_nifti...") write_nifti(image_np, output_file, affine=affine, output_dtype=dtype) else: @@ -113,6 +204,17 @@ def __call__(self, data): return output_file, output_json + def is_multichannel_image(self, image_np: np.ndarray) -> bool: + """Check if the provided image contains multiple channels + + Args: + image_np : Expected shape (channels, width, height, batch) + + Returns: + bool: If this is a multi-channel image or not + """ + return len(image_np.shape) == 4 and image_np.shape[0] > 1 + class ClassificationWriter: def __init__(self, label="pred", label_names=None): diff --git a/requirements.txt b/requirements.txt index c4ac60408..283506a45 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,5 +26,6 @@ opencv-python-headless==4.5.5.64 Shapely==1.8.1.post1 girder_client==3.1.8 numpymaxflow==0.0.2 +pynrrd==0.4.2 #sudo apt-get install openslide-tools -y diff --git a/setup.cfg b/setup.cfg index 30febdfd3..d5516a109 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,6 +51,7 @@ install_requires = Shapely==1.8.1.post1 girder_client==3.1.8 numpymaxflow==0.0.2 + pynrrd==0.4.2 [flake8] select = B,C,E,F,N,P,T4,W,B9 diff --git a/tests/unit/transform/test_writer.py b/tests/unit/transform/test_writer.py index 3df178b7b..4251db783 100644 --- a/tests/unit/transform/test_writer.py +++ b/tests/unit/transform/test_writer.py @@ -2,6 +2,7 @@ import pathlib import unittest +import nrrd import numpy as np from parameterized import parameterized @@ -18,6 +19,17 @@ }, ] +CHANNELS = 2 +WIDTH = 15 +HEIGHT = 10 +MULTI_CHANNEL_DATA = np.zeros((CHANNELS, WIDTH, HEIGHT, 1)) + +COLOR_MAP = { + # according to getLabelColor() [https://github.com/Project-MONAI/MONAILabel/blob/6cc72c542c9bc6c5181af89550e7e397537d74e3/plugins/slicer/MONAILabel/MONAILabel.py#L1485] # noqa + "lung": [128 / 255, 174 / 255, 128 / 255], # green + "heart": [206 / 255, 110 / 255, 84 / 255], # red +} + class TestWriter(unittest.TestCase): @parameterized.expand([WRITER_DATA]) @@ -29,6 +41,32 @@ def test_nifti(self, args, input_data): file_ext = "".join(pathlib.Path(input_data["image_path"]).suffixes) self.assertIn(file_ext.lower(), [".nii", ".nii.gz"]) + @parameterized.expand([WRITER_DATA]) + def test_seg_nrrd(self, args, input_data): + args.update({"nibabel": False}) + input_data["pred"] = MULTI_CHANNEL_DATA + input_data["result_extension"] = ".seg.nrrd" + input_data["labels"] = ["heart", "lung"] + input_data["color_map"] = COLOR_MAP + + output_file, data = Writer(**args)(input_data) + self.assertEqual(os.path.exists(output_file), True) + arr_full, header = nrrd.read(output_file) + + self.assertEqual(arr_full.shape, (CHANNELS, WIDTH, HEIGHT, 1)) + + space_directions_expected = np.array( + [[np.nan, np.nan, np.nan], [-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 1.0]] + ) + self.assertTrue(np.array_equal(header["space directions"], space_directions_expected, equal_nan=True)) + + self.assertEqual(header["kinds"], ["list", "domain", "domain", "domain"]) + self.assertEqual(header["Segment1_ID"], "lung") + self.assertEqual(header["Segment1_Color"], " ".join(map(str, COLOR_MAP["lung"]))) + + file_ext = "".join(pathlib.Path(output_file).suffixes) + self.assertIn(file_ext.lower(), [".seg.nrrd"]) + @parameterized.expand([WRITER_DATA]) def test_itk(self, args, input_data): args.update({"nibabel": False})