diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f9f7a4159c..959e42d6f9 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -98,3 +98,15 @@ Clara MMARs .. autofunction:: compute_isolated_tumor_cells .. autoclass:: PathologyProbNMS :members: + +.. automodule:: monai.apps.pathology.transforms.stain.array +.. autoclass:: ExtractHEStains + :members: +.. autoclass:: NormalizeHEStains + :members: + +.. automodule:: monai.apps.pathology.transforms.stain.dictionary +.. autoclass:: ExtractHEStainsd + :members: +.. autoclass:: NormalizeHEStainsd + :members: diff --git a/monai/apps/pathology/__init__.py b/monai/apps/pathology/__init__.py index 203e1a80d7..0ada8fe51b 100644 --- a/monai/apps/pathology/__init__.py +++ b/monai/apps/pathology/__init__.py @@ -12,4 +12,13 @@ from .datasets import MaskedInferenceWSIDataset, PatchWSIDataset, SmartCacheDataset from .handlers import ProbMapProducer from .metrics import LesionFROC +from .transforms.stain.array import ExtractHEStains, NormalizeHEStains +from .transforms.stain.dictionary import ( + ExtractHEStainsd, + ExtractHEStainsD, + ExtractHEStainsDict, + NormalizeHEStainsd, + NormalizeHEStainsD, + NormalizeHEStainsDict, +) from .utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py new file mode 100644 index 0000000000..0df016244b --- /dev/null +++ b/monai/apps/pathology/transforms/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .stain.array import ExtractHEStains, NormalizeHEStains +from .stain.dictionary import ( + ExtractHEStainsd, + ExtractHEStainsD, + ExtractHEStainsDict, + NormalizeHEStainsd, + NormalizeHEStainsD, + NormalizeHEStainsDict, +) diff --git a/monai/apps/pathology/transforms/stain/__init__.py b/monai/apps/pathology/transforms/stain/__init__.py new file mode 100644 index 0000000000..824f40a579 --- /dev/null +++ b/monai/apps/pathology/transforms/stain/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .array import ExtractHEStains, NormalizeHEStains +from .dictionary import ( + ExtractHEStainsd, + ExtractHEStainsD, + ExtractHEStainsDict, + NormalizeHEStainsd, + NormalizeHEStainsD, + NormalizeHEStainsDict, +) diff --git a/monai/apps/pathology/transforms/stain/array.py b/monai/apps/pathology/transforms/stain/array.py new file mode 100644 index 0000000000..ccddc6b243 --- /dev/null +++ b/monai/apps/pathology/transforms/stain/array.py @@ -0,0 +1,196 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +import numpy as np + +from monai.transforms.transform import Transform + + +class ExtractHEStains(Transform): + """Class to extract a target stain from an image, using stain deconvolution (see Note). + + Args: + tli: transmitted light intensity. Defaults to 240. + alpha: tolerance in percentile for the pseudo-min (alpha percentile) + and pseudo-max (100 - alpha percentile). Defaults to 1. + beta: absorbance threshold for transparent pixels. Defaults to 0.15 + max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E). + Defaults to (1.9705, 1.0308). + + Note: + For more information refer to: + - the original paper: Macenko et al., 2009 http://wwwx.cs.unc.edu/~mn/sites/default/files/macenko2009.pdf + - the previous implementations: + + - MATLAB: https://github.com/mitkovetta/staining-normalization + - Python: https://github.com/schaugf/HEnorm_python + """ + + def __init__( + self, + tli: float = 240, + alpha: float = 1, + beta: float = 0.15, + max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308), + ) -> None: + self.tli = tli + self.alpha = alpha + self.beta = beta + self.max_cref = np.array(max_cref) + + def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray: + """Perform Stain Deconvolution and return stain matrix for the image. + + Args: + img: uint8 RGB image to perform stain deconvolution on + + Return: + he: H&E absorbance matrix for the image (first column is H, second column is E, rows are RGB values) + """ + # check image type and vlues + if not isinstance(image, np.ndarray): + raise TypeError("Image must be of type numpy.ndarray.") + if image.min() < 0: + raise ValueError("Image should not have negative values.") + if image.max() > 255: + raise ValueError("Image should not have values greater than 255.") + + # reshape image and calculate absorbance + image = image.reshape((-1, 3)) + image = image.astype(np.float32) + 1.0 + absorbance = -np.log(image.clip(max=self.tli) / self.tli) + + # remove transparent pixels + absorbance_hat = absorbance[np.all(absorbance > self.beta, axis=1)] + if len(absorbance_hat) == 0: + raise ValueError("All pixels of the input image are below the absorbance threshold.") + + # compute eigenvectors + _, eigvecs = np.linalg.eigh(np.cov(absorbance_hat.T).astype(np.float32)) + + # project on the plane spanned by the eigenvectors corresponding to the two largest eigenvalues + t_hat = absorbance_hat.dot(eigvecs[:, 1:3]) + + # find the min and max vectors and project back to absorbance space + phi = np.arctan2(t_hat[:, 1], t_hat[:, 0]) + min_phi = np.percentile(phi, self.alpha) + max_phi = np.percentile(phi, 100 - self.alpha) + v_min = eigvecs[:, 1:3].dot(np.array([(np.cos(min_phi), np.sin(min_phi))], dtype=np.float32).T) + v_max = eigvecs[:, 1:3].dot(np.array([(np.cos(max_phi), np.sin(max_phi))], dtype=np.float32).T) + + # a heuristic to make the vector corresponding to hematoxylin first and the one corresponding to eosin second + if v_min[0] > v_max[0]: + he = np.array((v_min[:, 0], v_max[:, 0]), dtype=np.float32).T + else: + he = np.array((v_max[:, 0], v_min[:, 0]), dtype=np.float32).T + + return he + + def __call__(self, image: np.ndarray) -> np.ndarray: + """Perform stain extraction. + + Args: + image: uint8 RGB image to extract stain from + + return: + target_he: H&E absorbance matrix for the image (first column is H, second column is E, rows are RGB values) + """ + if not isinstance(image, np.ndarray): + raise TypeError("Image must be of type numpy.ndarray.") + + target_he = self._deconvolution_extract_stain(image) + return target_he + + +class NormalizeHEStains(Transform): + """Class to normalize patches/images to a reference or target image stain (see Note). + + Performs stain deconvolution of the source image using the ExtractHEStains + class, to obtain the stain matrix and calculate the stain concentration matrix + for the image. Then, performs the inverse Beer-Lambert transform to recreate the + patch using the target H&E stain matrix provided. If no target stain provided, a default + reference stain is used. Similarly, if no maximum stain concentrations are provided, a + reference maximum stain concentrations matrix is used. + + Args: + tli: transmitted light intensity. Defaults to 240. + alpha: tolerance in percentile for the pseudo-min (alpha percentile) and + pseudo-max (100 - alpha percentile). Defaults to 1. + beta: absorbance threshold for transparent pixels. Defaults to 0.15. + target_he: target stain matrix. Defaults to ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)). + max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E). + Defaults to [1.9705, 1.0308]. + + Note: + For more information refer to: + - the original paper: Macenko et al., 2009 http://wwwx.cs.unc.edu/~mn/sites/default/files/macenko2009.pdf + - the previous implementations: + + - MATLAB: https://github.com/mitkovetta/staining-normalization + - Python: https://github.com/schaugf/HEnorm_python + """ + + def __init__( + self, + tli: float = 240, + alpha: float = 1, + beta: float = 0.15, + target_he: Union[tuple, np.ndarray] = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)), + max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308), + ) -> None: + self.tli = tli + self.target_he = np.array(target_he) + self.max_cref = np.array(max_cref) + self.stain_extractor = ExtractHEStains(tli=self.tli, alpha=alpha, beta=beta, max_cref=self.max_cref) + + def __call__(self, image: np.ndarray) -> np.ndarray: + """Perform stain normalization. + + Args: + image: uint8 RGB image/patch to be stain normalized, pixel values between 0 and 255 + + Return: + image_norm: stain normalized image/patch + """ + # check image type and vlues + if not isinstance(image, np.ndarray): + raise TypeError("Image must be of type numpy.ndarray.") + if image.min() < 0: + raise ValueError("Image should not have negative values.") + if image.max() > 255: + raise ValueError("Image should not have values greater than 255.") + + # extract stain of the image + he = self.stain_extractor(image) + + # reshape image and calculate absorbance + h, w, _ = image.shape + image = image.reshape((-1, 3)) + image = image.astype(np.float32) + 1.0 + absorbance = -np.log(image.clip(max=self.tli) / self.tli) + + # rows correspond to channels (RGB), columns to absorbance values + y = np.reshape(absorbance, (-1, 3)).T + + # determine concentrations of the individual stains + conc = np.linalg.lstsq(he, y, rcond=None)[0] + + # normalize stain concentrations + max_conc = np.array([np.percentile(conc[0, :], 99), np.percentile(conc[1, :], 99)], dtype=np.float32) + tmp = np.divide(max_conc, self.max_cref, dtype=np.float32) + image_c = np.divide(conc, tmp[:, np.newaxis], dtype=np.float32) + + image_norm: np.ndarray = np.multiply(self.tli, np.exp(-self.target_he.dot(image_c)), dtype=np.float32) + image_norm[image_norm > 255] = 254 + image_norm = np.reshape(image_norm.T, (h, w, 3)).astype(np.uint8) + return image_norm diff --git a/monai/apps/pathology/transforms/stain/dictionary.py b/monai/apps/pathology/transforms/stain/dictionary.py new file mode 100644 index 0000000000..976af1e7c7 --- /dev/null +++ b/monai/apps/pathology/transforms/stain/dictionary.py @@ -0,0 +1,111 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of dictionary-based wrappers around the pathology transforms +defined in :py:class:`monai.apps.pathology.transforms.array`. + +Class names are ended with 'd' to denote dictionary-based transforms. +""" + +from typing import Dict, Hashable, Mapping, Union + +import numpy as np + +from monai.config import KeysCollection +from monai.transforms.transform import MapTransform + +from .array import ExtractHEStains, NormalizeHEStains + + +class ExtractHEStainsd(MapTransform): + """Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.ExtractHEStains`. + Class to extract a target stain from an image, using stain deconvolution. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + tli: transmitted light intensity. Defaults to 240. + alpha: tolerance in percentile for the pseudo-min (alpha percentile) + and pseudo-max (100 - alpha percentile). Defaults to 1. + beta: absorbance threshold for transparent pixels. Defaults to 0.15 + max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E). + Defaults to (1.9705, 1.0308). + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + tli: float = 240, + alpha: float = 1, + beta: float = 0.15, + max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308), + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.extractor = ExtractHEStains(tli=tli, alpha=alpha, beta=beta, max_cref=max_cref) + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.extractor(d[key]) + return d + + +class NormalizeHEStainsd(MapTransform): + """Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.NormalizeHEStains`. + + Class to normalize patches/images to a reference or target image stain. + + Performs stain deconvolution of the source image using the ExtractHEStains + class, to obtain the stain matrix and calculate the stain concentration matrix + for the image. Then, performs the inverse Beer-Lambert transform to recreate the + patch using the target H&E stain matrix provided. If no target stain provided, a default + reference stain is used. Similarly, if no maximum stain concentrations are provided, a + reference maximum stain concentrations matrix is used. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + tli: transmitted light intensity. Defaults to 240. + alpha: tolerance in percentile for the pseudo-min (alpha percentile) and + pseudo-max (100 - alpha percentile). Defaults to 1. + beta: absorbance threshold for transparent pixels. Defaults to 0.15. + target_he: target stain matrix. Defaults to None. + max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E). + Defaults to None. + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + tli: float = 240, + alpha: float = 1, + beta: float = 0.15, + target_he: Union[tuple, np.ndarray] = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)), + max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308), + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.normalizer = NormalizeHEStains(tli=tli, alpha=alpha, beta=beta, target_he=target_he, max_cref=max_cref) + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.normalizer(d[key]) + return d + + +ExtractHEStainsDict = ExtractHEStainsD = ExtractHEStainsd +NormalizeHEStainsDict = NormalizeHEStainsD = NormalizeHEStainsd diff --git a/tests/test_pathology_he_stain.py b/tests/test_pathology_he_stain.py new file mode 100644 index 0000000000..1d74f485e9 --- /dev/null +++ b/tests/test_pathology_he_stain.py @@ -0,0 +1,243 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms import ExtractHEStains, NormalizeHEStains + +# None inputs +EXTRACT_STAINS_TEST_CASE_0 = (None,) +EXTRACT_STAINS_TEST_CASE_00 = (None, None) +NORMALIZE_STAINS_TEST_CASE_0 = (None,) +NORMALIZE_STAINS_TEST_CASE_00: tuple = ({}, None, None) + +# input pixels with negative values +NEGATIVE_VALUE_TEST_CASE = [np.full((3, 2, 3), -1)] + +# input pixels with greater than 255 values +INVALID_VALUE_TEST_CASE = [np.full((3, 2, 3), 256)] + +# input pixels all transparent and below the beta absorbance threshold +EXTRACT_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)] + +# input pixels uniformly filled, but above beta absorbance threshold +EXTRACT_STAINS_TEST_CASE_2 = [np.full((3, 2, 3), 100)] + +# input pixels uniformly filled (different value), but above beta absorbance threshold +EXTRACT_STAINS_TEST_CASE_3 = [np.full((3, 2, 3), 150)] + +# input pixels uniformly filled with zeros, leading to two identical stains extracted +EXTRACT_STAINS_TEST_CASE_4 = [ + np.zeros((3, 2, 3)), + np.array([[0.0, 0.0], [0.70710678, 0.70710678], [0.70710678, 0.70710678]]), +] + +# input pixels not uniformly filled, leading to two different stains extracted +EXTRACT_STAINS_TEST_CASE_5 = [ + np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]), + np.array([[0.70710677, 0.18696113], [0.0, 0.0], [0.70710677, 0.98236734]]), +] + + +# input pixels all transparent and below the beta absorbance threshold +NORMALIZE_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)] + +# input pixels uniformly filled with zeros, and target stain matrix provided +NORMALIZE_STAINS_TEST_CASE_2 = [{"target_he": np.full((3, 2), 1)}, np.zeros((3, 2, 3)), np.full((3, 2, 3), 11)] + +# input pixels uniformly filled with zeros, and target stain matrix not provided +NORMALIZE_STAINS_TEST_CASE_3 = [ + {}, + np.zeros((3, 2, 3)), + np.array([[[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]]]), +] + +# input pixels not uniformly filled +NORMALIZE_STAINS_TEST_CASE_4 = [ + {"target_he": np.full((3, 2), 1)}, + np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]), + np.array([[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]]]), +] + + +class TestExtractHEStains(unittest.TestCase): + @parameterized.expand( + [ + NEGATIVE_VALUE_TEST_CASE, + INVALID_VALUE_TEST_CASE, + EXTRACT_STAINS_TEST_CASE_0, + EXTRACT_STAINS_TEST_CASE_1, + ] + ) + def test_transparent_image(self, image): + """ + Test HE stain extraction on an image that comprises + only transparent pixels - pixels with absorbance below the + beta absorbance threshold. A ValueError should be raised, + since once the transparent pixels are removed, there are no + remaining pixels to compute eigenvectors. + """ + if image is None: + with self.assertRaises(TypeError): + ExtractHEStains()(image) + else: + with self.assertRaises(ValueError): + ExtractHEStains()(image) + + @parameterized.expand([EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_2, EXTRACT_STAINS_TEST_CASE_3]) + def test_identical_result_vectors(self, image): + """ + Test HE stain extraction on input images that are + uniformly filled with pixels that have absorbance above the + beta absorbance threshold. Since input image is uniformly filled, + the two extracted stains should have the same RGB values. So, + we assert that the first column is equal to the second column + of the returned stain matrix. + """ + if image is None: + with self.assertRaises(TypeError): + ExtractHEStains()(image) + else: + result = ExtractHEStains()(image) + np.testing.assert_array_equal(result[:, 0], result[:, 1]) + + @parameterized.expand( + [ + EXTRACT_STAINS_TEST_CASE_00, + EXTRACT_STAINS_TEST_CASE_4, + EXTRACT_STAINS_TEST_CASE_5, + ] + ) + def test_result_value(self, image, expected_data): + """ + Test that an input image returns an expected stain matrix. + + For test case 4: + - a uniformly filled input image should result in + eigenvectors [[1,0,0],[0,1,0],[0,0,1]] + - phi should be an array containing only values of + arctan(1) since the ratio between the eigenvectors + corresponding to the two largest eigenvalues is 1 + - maximum phi and minimum phi should thus be arctan(1) + - thus, maximum vector and minimum vector should be + [[0],[0.70710677],[0.70710677]] + - the resulting extracted stain should be + [[0,0],[0.70710678,0.70710678],[0.70710678,0.70710678]] + + For test case 5: + - the non-uniformly filled input image should result in + eigenvectors [[0,0,1],[1,0,0],[0,1,0]] + - maximum phi and minimum phi should thus be 0.785 and + 0.188 respectively + - thus, maximum vector and minimum vector should be + [[0.18696113],[0],[0.98236734]] and + [[0.70710677],[0],[0.70710677]] respectively + - the resulting extracted stain should be + [[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]] + """ + if image is None: + with self.assertRaises(TypeError): + ExtractHEStains()(image) + else: + result = ExtractHEStains()(image) + np.testing.assert_allclose(result, expected_data) + + +class TestNormalizeHEStains(unittest.TestCase): + @parameterized.expand( + [ + NEGATIVE_VALUE_TEST_CASE, + INVALID_VALUE_TEST_CASE, + NORMALIZE_STAINS_TEST_CASE_0, + NORMALIZE_STAINS_TEST_CASE_1, + ] + ) + def test_transparent_image(self, image): + """ + Test HE stain normalization on an image that comprises + only transparent pixels - pixels with absorbance below the + beta absorbance threshold. A ValueError should be raised, + since once the transparent pixels are removed, there are no + remaining pixels to compute eigenvectors. + """ + if image is None: + with self.assertRaises(TypeError): + NormalizeHEStains()(image) + else: + with self.assertRaises(ValueError): + NormalizeHEStains()(image) + + @parameterized.expand( + [ + NORMALIZE_STAINS_TEST_CASE_00, + NORMALIZE_STAINS_TEST_CASE_2, + NORMALIZE_STAINS_TEST_CASE_3, + NORMALIZE_STAINS_TEST_CASE_4, + ] + ) + def test_result_value(self, argments, image, expected_data): + """ + Test that an input image returns an expected normalized image. + + For test case 2: + - This case tests calling the stain normalizer, after the + _deconvolution_extract_conc function. This is because the normalized + concentration returned for each pixel is the same as the reference + maximum stain concentrations in the case that the image is uniformly + filled, as in this test case. This is because the maximum concentration + for each stain is the same as each pixel's concentration. + - Thus, the normalized concentration matrix should be a (2, 6) matrix + with the first row having all values of 1.9705, second row all 1.0308. + - Taking the matrix product of the target stain matrix and the concentration + matrix, then using the inverse Beer-Lambert transform to obtain the RGB + image from the absorbance image, and finally converting to uint8, + we get that the stain normalized image should be a matrix of + dims (3, 2, 3), with all values 11. + + For test case 3: + - This case also tests calling the stain normalizer, after the + _deconvolution_extract_conc function returns the image concentration + matrix. + - As in test case 2, the normalized concentration matrix should be a (2, 6) matrix + with the first row having all values of 1.9705, second row all 1.0308. + - Taking the matrix product of the target default stain matrix and the concentration + matrix, then using the inverse Beer-Lambert transform to obtain the RGB + image from the absorbance image, and finally converting to uint8, + we get that the stain normalized image should be [[[63, 25, 60], [63, 25, 60]], + [[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]]] + + For test case 4: + - For this non-uniformly filled image, the stain extracted should be + [[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]], as validated for the + ExtractHEStains class. Solving the linear least squares problem (since + absorbance matrix = stain matrix * concentration matrix), we obtain the concentration + matrix that should be [[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508], + [5.8022, 0, 0, 0, 0, 0]] + - Normalizing the concentration matrix, taking the matrix product of the + target stain matrix and the concentration matrix, using the inverse + Beer-Lambert transform to obtain the RGB image from the absorbance + image, and finally converting to uint8, we get that the stain normalized + image should be [[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]], + [[33, 33, 33], [33, 33, 33]]] + """ + if image is None: + with self.assertRaises(TypeError): + NormalizeHEStains()(image) + else: + result = NormalizeHEStains(**argments)(image) + np.testing.assert_allclose(result, expected_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pathology_he_stain_dict.py b/tests/test_pathology_he_stain_dict.py new file mode 100644 index 0000000000..8d51579cb2 --- /dev/null +++ b/tests/test_pathology_he_stain_dict.py @@ -0,0 +1,227 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms import ExtractHEStainsD, NormalizeHEStainsD + +# None inputs +EXTRACT_STAINS_TEST_CASE_0 = (None,) +EXTRACT_STAINS_TEST_CASE_00 = (None, None) +NORMALIZE_STAINS_TEST_CASE_0 = (None,) +NORMALIZE_STAINS_TEST_CASE_00: tuple = ({}, None, None) + +# input pixels all transparent and below the beta absorbance threshold +EXTRACT_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)] + +# input pixels uniformly filled, but above beta absorbance threshold +EXTRACT_STAINS_TEST_CASE_2 = [np.full((3, 2, 3), 100)] + +# input pixels uniformly filled (different value), but above beta absorbance threshold +EXTRACT_STAINS_TEST_CASE_3 = [np.full((3, 2, 3), 150)] + +# input pixels uniformly filled with zeros, leading to two identical stains extracted +EXTRACT_STAINS_TEST_CASE_4 = [ + np.zeros((3, 2, 3)), + np.array([[0.0, 0.0], [0.70710678, 0.70710678], [0.70710678, 0.70710678]]), +] + +# input pixels not uniformly filled, leading to two different stains extracted +EXTRACT_STAINS_TEST_CASE_5 = [ + np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]), + np.array([[0.70710677, 0.18696113], [0.0, 0.0], [0.70710677, 0.98236734]]), +] + +# input pixels all transparent and below the beta absorbance threshold +NORMALIZE_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)] + +# input pixels uniformly filled with zeros, and target stain matrix provided +NORMALIZE_STAINS_TEST_CASE_2 = [{"target_he": np.full((3, 2), 1)}, np.zeros((3, 2, 3)), np.full((3, 2, 3), 11)] + +# input pixels uniformly filled with zeros, and target stain matrix not provided +NORMALIZE_STAINS_TEST_CASE_3 = [ + {}, + np.zeros((3, 2, 3)), + np.array([[[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]]]), +] + +# input pixels not uniformly filled +NORMALIZE_STAINS_TEST_CASE_4 = [ + {"target_he": np.full((3, 2), 1)}, + np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]), + np.array([[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]]]), +] + + +class TestExtractHEStainsD(unittest.TestCase): + @parameterized.expand([EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_1]) + def test_transparent_image(self, image): + """ + Test HE stain extraction on an image that comprises + only transparent pixels - pixels with absorbance below the + beta absorbance threshold. A ValueError should be raised, + since once the transparent pixels are removed, there are no + remaining pixels to compute eigenvectors. + """ + key = "image" + if image is None: + with self.assertRaises(TypeError): + ExtractHEStainsD([key])({key: image}) + else: + with self.assertRaises(ValueError): + ExtractHEStainsD([key])({key: image}) + + @parameterized.expand([EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_2, EXTRACT_STAINS_TEST_CASE_3]) + def test_identical_result_vectors(self, image): + """ + Test HE stain extraction on input images that are + uniformly filled with pixels that have absorbance above the + beta absorbance threshold. Since input image is uniformly filled, + the two extracted stains should have the same RGB values. So, + we assert that the first column is equal to the second column + of the returned stain matrix. + """ + key = "image" + if image is None: + with self.assertRaises(TypeError): + ExtractHEStainsD([key])({key: image}) + else: + result = ExtractHEStainsD([key])({key: image}) + np.testing.assert_array_equal(result[key][:, 0], result[key][:, 1]) + + @parameterized.expand( + [ + EXTRACT_STAINS_TEST_CASE_00, + EXTRACT_STAINS_TEST_CASE_4, + EXTRACT_STAINS_TEST_CASE_5, + ] + ) + def test_result_value(self, image, expected_data): + """ + Test that an input image returns an expected stain matrix. + + For test case 4: + - a uniformly filled input image should result in + eigenvectors [[1,0,0],[0,1,0],[0,0,1]] + - phi should be an array containing only values of + arctan(1) since the ratio between the eigenvectors + corresponding to the two largest eigenvalues is 1 + - maximum phi and minimum phi should thus be arctan(1) + - thus, maximum vector and minimum vector should be + [[0],[0.70710677],[0.70710677]] + - the resulting extracted stain should be + [[0,0],[0.70710678,0.70710678],[0.70710678,0.70710678]] + + For test case 5: + - the non-uniformly filled input image should result in + eigenvectors [[0,0,1],[1,0,0],[0,1,0]] + - maximum phi and minimum phi should thus be 0.785 and + 0.188 respectively + - thus, maximum vector and minimum vector should be + [[0.18696113],[0],[0.98236734]] and + [[0.70710677],[0],[0.70710677]] respectively + - the resulting extracted stain should be + [[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]] + """ + key = "image" + if image is None: + with self.assertRaises(TypeError): + ExtractHEStainsD([key])({key: image}) + else: + result = ExtractHEStainsD([key])({key: image}) + np.testing.assert_allclose(result[key], expected_data) + + +class TestNormalizeHEStainsD(unittest.TestCase): + @parameterized.expand([NORMALIZE_STAINS_TEST_CASE_0, NORMALIZE_STAINS_TEST_CASE_1]) + def test_transparent_image(self, image): + """ + Test HE stain normalization on an image that comprises + only transparent pixels - pixels with absorbance below the + beta absorbance threshold. A ValueError should be raised, + since once the transparent pixels are removed, there are no + remaining pixels to compute eigenvectors. + """ + key = "image" + if image is None: + with self.assertRaises(TypeError): + NormalizeHEStainsD([key])({key: image}) + else: + with self.assertRaises(ValueError): + NormalizeHEStainsD([key])({key: image}) + + @parameterized.expand( + [ + NORMALIZE_STAINS_TEST_CASE_00, + NORMALIZE_STAINS_TEST_CASE_2, + NORMALIZE_STAINS_TEST_CASE_3, + NORMALIZE_STAINS_TEST_CASE_4, + ] + ) + def test_result_value(self, argments, image, expected_data): + """ + Test that an input image returns an expected normalized image. + + For test case 2: + - This case tests calling the stain normalizer, after the + _deconvolution_extract_conc function. This is because the normalized + concentration returned for each pixel is the same as the reference + maximum stain concentrations in the case that the image is uniformly + filled, as in this test case. This is because the maximum concentration + for each stain is the same as each pixel's concentration. + - Thus, the normalized concentration matrix should be a (2, 6) matrix + with the first row having all values of 1.9705, second row all 1.0308. + - Taking the matrix product of the target stain matrix and the concentration + matrix, then using the inverse Beer-Lambert transform to obtain the RGB + image from the absorbance image, and finally converting to uint8, + we get that the stain normalized image should be a matrix of + dims (3, 2, 3), with all values 11. + + For test case 3: + - This case also tests calling the stain normalizer, after the + _deconvolution_extract_conc function returns the image concentration + matrix. + - As in test case 2, the normalized concentration matrix should be a (2, 6) matrix + with the first row having all values of 1.9705, second row all 1.0308. + - Taking the matrix product of the target default stain matrix and the concentration + matrix, then using the inverse Beer-Lambert transform to obtain the RGB + image from the absorbance image, and finally converting to uint8, + we get that the stain normalized image should be [[[63, 25, 60], [63, 25, 60]], + [[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]]] + + For test case 4: + - For this non-uniformly filled image, the stain extracted should be + [[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]], as validated for the + ExtractHEStains class. Solving the linear least squares problem (since + absorbance matrix = stain matrix * concentration matrix), we obtain the concentration + matrix that should be [[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508], + [5.8022, 0, 0, 0, 0, 0]] + - Normalizing the concentration matrix, taking the matrix product of the + target stain matrix and the concentration matrix, using the inverse + Beer-Lambert transform to obtain the RGB image from the absorbance + image, and finally converting to uint8, we get that the stain normalized + image should be [[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]], + [[33, 33, 33], [33, 33, 33]]] + """ + key = "image" + if image is None: + with self.assertRaises(TypeError): + NormalizeHEStainsD([key])({key: image}) + else: + result = NormalizeHEStainsD([key], **argments)({key: image}) + np.testing.assert_allclose(result[key], expected_data) + + +if __name__ == "__main__": + unittest.main()