diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 0000000..e6719d7 --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,44 @@ +name: Build and upload to PyPI + +on: + push + +jobs: + build_wheels: + name: Build release + runs-on: ubuntu-18.04 + + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + with: + python-version: '3.x' + + - name: Install deps + run: | + pip install setuptools wheel + + - name: Build wheels + run: python setup.py sdist bdist_wheel + + - uses: actions/upload-artifact@v3 + with: + path: ./dist/* + + upload_pypi: + needs: build_wheels + runs-on: ubuntu-18.04 + + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') + + steps: + - uses: actions/download-artifact@v3 + with: + name: artifact + path: dist + + - uses: pypa/gh-action-pypi-publish@v1.5.0 + with: + user: __token__ + password: ${{ secrets.torchstain_deploy_token }} + \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests_full.yml similarity index 51% rename from .github/workflows/tests.yml rename to .github/workflows/tests_full.yml index 3849cee..e550c88 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests_full.yml @@ -3,14 +3,17 @@ name: tests on: push: branches: - - '*' + - main pull_request: branches: - - '*' + - main jobs: build: runs-on: ubuntu-18.04 + + if: startsWith(github.ref, 'refs/tags/v') != true + steps: - uses: actions/checkout@v1 - name: Set up Python 3.6 @@ -31,46 +34,14 @@ jobs: path: ${{github.workspace}}/dist/torchstain-*.whl if-no-files-found: error - test: - needs: build - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [windows-2019, ubuntu-18.04, macos-11] - python-version: [3.6, 3.7, 3.8, 3.9] - - steps: - - uses: actions/checkout@v1 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - - name: Download artifact - 
uses: actions/download-artifact@master - with: - name: "Python wheel" - - - name: Install deps and wheel - run: | - pip install tensorflow torch - pip install --find-links=${{github.workspace}} torchstain - - - name: Install test dependencies - run: pip install opencv-python torchvision scikit-image pytest - - - name: Run tests - run: | - pytest -v tests/test_utils.py - pytest -v tests/test_normalizers.py - test-tf: needs: build runs-on: ${{ matrix.os }} strategy: matrix: os: [ windows-2019, ubuntu-18.04, macos-11 ] - python-version: [ 3.7 ] + python-version: [ 3.7, 3.8, 3.9 ] + tf-version: [2.7.0, 2.8.0, 2.9.0] steps: - uses: actions/checkout@v1 @@ -84,16 +55,16 @@ jobs: with: name: "Python wheel" - - name: Install deps and wheel + - name: Install dependencies run: | - pip install tensorflow - pip install --find-links=${{github.workspace}} torchstain + pip install tensorflow==${{ matrix.tf-version }} protobuf==3.20.* opencv-python-headless scikit-image + pip install pytest - - name: Install test dependencies - run: pip install opencv-python scikit-image pytest + - name: Install wheel + run: pip install --find-links=${{github.workspace}} torchstain - name: Run tests - run: pytest -v tests/test_tf_normalizer.py + run: pytest -vs tests/test_tf.py test-torch: needs: build @@ -101,7 +72,13 @@ jobs: strategy: matrix: os: [ windows-2019, ubuntu-18.04, macos-11 ] - python-version: [ 3.7 ] + python-version: [ 3.6, 3.7, 3.8, 3.9 ] + pytorch-version: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0] + exclude: + - python-version: 3.6 + pytorch-version: 1.11.0 + - python-version: 3.6 + pytorch-version: 1.12.0 steps: - uses: actions/checkout@v1 @@ -115,13 +92,13 @@ jobs: with: name: "Python wheel" - - name: Install deps and wheel wheel + - name: Install dependencies run: | - pip install torch - pip install --find-links=${{github.workspace}} torchstain + pip install torch==${{ matrix.pytorch-version }} torchvision opencv-python-headless scikit-image + pip install pytest - - name: Install 
test dependencies - run: pip install opencv-python torchvision scikit-image pytest + - name: Install wheel + run: pip install --find-links=${{github.workspace}} torchstain - name: Run tests - run: pytest -v tests/test_torch_normalizer.py + run: pytest -vs tests/test_torch.py diff --git a/.github/workflows/tests_quick.yml b/.github/workflows/tests_quick.yml new file mode 100644 index 0000000..3e08828 --- /dev/null +++ b/.github/workflows/tests_quick.yml @@ -0,0 +1,86 @@ +name: tests + +on: + push: + branches-ignore: + - main + pull_request: + branches-ignore: + - main + +jobs: + build: + runs-on: ubuntu-18.04 + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.6 + uses: actions/setup-python@v2 + with: + python-version: 3.6 + + - name: Install dependencies + run: pip install wheel setuptools + + - name: Build wheel + run: python setup.py bdist_wheel + + - name: Upload Python wheel + uses: actions/upload-artifact@v2 + with: + name: Python wheel + path: ${{github.workspace}}/dist/torchstain-*.whl + if-no-files-found: error + + test-tf: + needs: build + runs-on: ubuntu-18.04 + + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Download artifact + uses: actions/download-artifact@master + with: + name: "Python wheel" + + - name: Install dependencies + run: | + pip install tensorflow protobuf==3.20.* opencv-python-headless scikit-image + pip install pytest + + - name: Install wheel + run: pip install --find-links=${{github.workspace}} torchstain + + - name: Run tests + run: pytest -vs tests/test_tf.py + + test-torch: + needs: build + runs-on: ubuntu-18.04 + + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Download artifact + uses: actions/download-artifact@master + with: + name: "Python wheel" + + - name: Install dependencies + run: | + pip install torch torchvision 
opencv-python-headless scikit-image + pip install pytest + + - name: Install wheel + run: pip install --find-links=${{github.workspace}} torchstain + + - name: Run tests + run: pytest -vs tests/test_torch.py diff --git a/README.md b/README.md index 6d45a3e..d16b5ea 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,13 @@ [![License](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) [![tests](https://github.com/EIDOSLAB/torchstain/workflows/tests/badge.svg)](https://github.com/EIDOSLAB/torchstain/actions) [![Pip Downloads](https://img.shields.io/pypi/dm/torchstain?label=pip%20downloads&logo=python)](https://pypi.org/project/torchstain/) +[![DOI](https://zenodo.org/badge/323590093.svg)](https://zenodo.org/badge/latestdoi/323590093) GPU-accelerated stain normalization tools for histopathological images. Compatible with PyTorch, TensorFlow, and Numpy. Normalization algorithms currently implemented: - Macenko et al. [\[1\]](#reference) (ported from [numpy implementation](https://github.com/schaugf/HEnorm_python)) +- Reinhard et al. [\[2\]](#reference) ## Installation @@ -15,6 +17,8 @@ Normalization algorithms currently implemented: pip install torchstain ``` +To install a specific backend use either ```torchstain[torch]``` or ```torchstain[tf]```. The numpy backend is included by default in both. + ## Example Usage ```python @@ -38,26 +42,62 @@ t_to_transform = T(to_transform) norm, H, E = normalizer.normalize(I=t_to_transform, stains=True) ``` -![alt text](result.png) +![alt text](data/result.png) + +## Implemented algorithms +| Algorithm | numpy | torch | tensorflow | +|-|-|-|-| +| Macenko | ✓ | ✓ | ✓ | +| Reinhard | ✓ | ✓ | ✓ | ## Backend comparison Results with 10 runs per size on a Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz - -| size | numpy avg. time | numpy tot. time | torch avg. time | torch tot. 
time | -|--------|-------------------|-------------------|-------------------|-------------------| -| 224x224 | 0.0323s ± 0.0032 | 0.3231s | 0.0234s ± 0.0384 | 0.2340s | -| 448x448 | 0.1228s ± 0.0042 | 1.2280s | 0.0395s ± 0.0168 | 0.3954s | -| 672x672 | 0.2653s ± 0.0106 | 2.6534s | 0.0753s ± 0.0157 | 0.7527s | -| 896x896 | 0.4940s ± 0.0208 | 4.9397s | 0.1262s ± 0.0159 | 1.2622s | -| 1120x1120 | 0.6888s ± 0.0081 | 6.8883s | 0.2002s ± 0.0141 | 2.0021s | -| 1344x1344 | 1.0145s ± 0.0089 | 10.1448s | 0.2703s ± 0.0136 | 2.7026s | -| 1568x1568 | 1.2620s ± 0.0133 | 12.6200s | 0.3680s ± 0.0128 | 3.6795s | -| 1792x1792 | 1.4289s ± 0.0128 | 14.2886s | 0.5968s ± 0.0160 | 5.9676s | - +| size | numpy avg. time | torch avg. time | tf avg. time | +|--------|-------------------|-------------------|------------------| +| 224 | 0.0182s ± 0.0016 | 0.0180s ± 0.0390 | 0.0048s ± 0.0002 | +| 448 | 0.0880s ± 0.0224 | 0.0283s ± 0.0172 | 0.0210s ± 0.0025 | +| 672 | 0.1810s ± 0.0139 | 0.0463s ± 0.0301 | 0.0354s ± 0.0018 | +| 896 | 0.3013s ± 0.0377 | 0.0820s ± 0.0329 | 0.0713s ± 0.0008 | +| 1120 | 0.4694s ± 0.0350 | 0.1321s ± 0.0237 | 0.1036s ± 0.0042 | +| 1344 | 0.6640s ± 0.0553 | 0.1665s ± 0.0026 | 0.1663s ± 0.0021 | +| 1568 | 1.1935s ± 0.0739 | 0.2590s ± 0.0088 | 0.2531s ± 0.0031 | +| 1792 | 1.4523s ± 0.0207 | 0.3402s ± 0.0114 | 0.3080s ± 0.0188 | ## Reference - [1] Macenko, Marc, et al. "A method for normalizing histology slides for quantitative analysis." 2009 IEEE International Symposium on Biomedical Imaging: From Nano to Macro. IEEE, 2009. +- [2] Reinhard, Erik, et al. "Color transfer between images." IEEE Computer Graphics and Applications. IEEE, 2001. 
+ +## Citing + +If you find this software useful for your research, please cite it as: + +```bibtex +@software{barbano2022torchstain, + author = {Carlo Alberto Barbano and + André Pedersen}, + title = {EIDOSLAB/torchstain: v1.2.0-stable}, + month = aug, + year = 2022, + publisher = {Zenodo}, + version = {v1.2.0-stable}, + doi = {10.5281/zenodo.6979540}, + url = {https://doi.org/10.5281/zenodo.6979540} +} +``` + +Torchstain was originally developed within the [UNITOPATHO](https://github.com/EIDOSLAB/UNITOPATHO) data collection, which you can cite as: + +```bibtex +@inproceedings{barbano2021unitopatho, + title={UniToPatho, a labeled histopathological dataset for colorectal polyps classification and adenoma dysplasia grading}, + author={Barbano, Carlo Alberto and Perlo, Daniele and Tartaglione, Enzo and Fiandrotti, Attilio and Bertero, Luca and Cassoni, Paola and Grangetto, Marco}, + booktitle={2021 IEEE International Conference on Image Processing (ICIP)}, + pages={76--80}, + year={2021}, + organization={IEEE} +} +``` diff --git a/compare.py b/compare.py index 0fbf4f8..cacab8c 100644 --- a/compare.py +++ b/compare.py @@ -14,24 +14,26 @@ def measure(size, N): target = cv2.resize(cv2.cvtColor(cv2.imread("./data/target.png"), cv2.COLOR_BGR2RGB), (size, size)) to_transform = cv2.resize(cv2.cvtColor(cv2.imread("./data/source.png"), cv2.COLOR_BGR2RGB), (size, size)) - normalizer = torchstain.MacenkoNormalizer(backend='numpy') + normalizer = torchstain.normalizers.MacenkoNormalizer(backend='numpy') normalizer.fit(target) - T = transforms.Compose([ transforms.ToPILImage(), transforms.ToTensor(), transforms.Lambda(lambda x: x*255) ]) - torch_normalizer = torchstain.MacenkoNormalizer(backend='torch') + torch_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='torch') torch_normalizer.fit(T(target)) + tf_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='tensorflow') + tf_normalizer.fit(T(target)) + t_to_transform = T(to_transform) t_np = [] start_np = 
time.perf_counter() - for i in range(N): + for _ in range(N): tic = time.perf_counter() _ = normalizer.normalize(to_transform) toc = time.perf_counter() @@ -42,7 +44,7 @@ def measure(size, N): t_torch = [] start_torch = time.perf_counter() - for i in range(N): + for _ in range(N): tic = time.perf_counter() _ = torch_normalizer.normalize(t_to_transform) toc = time.perf_counter() @@ -50,18 +52,32 @@ def measure(size, N): end_torch = time.perf_counter() t_torch = np.array(t_torch) + + t_tf = [] + start_tf = time.perf_counter() + for _ in range(N): + tic = time.perf_counter() + _ = tf_normalizer.normalize(t_to_transform) + toc = time.perf_counter() + t_tf.append(toc-tic) + end_tf = time.perf_counter() + t_tf = np.array(t_tf) + """ print(f'Results of {N} runs:') print(f'numpy: {t_np.mean():.4f}s ± {t_np.std():.4f} (tot: {end_np-start_np:.4f}s)') print(f'torch: {t_torch.mean():.4f}s ± {t_torch.std():.4f} (tot: {end_torch-start_torch:.4f}s)') """ - return t_np, end_np-start_np, t_torch, end_torch-start_torch + return t_np, end_np-start_np, t_torch, end_torch-start_torch, t_tf, end_tf-start_tf table = [] for size in [224, 448, 672, 896, 1120, 1344, 1568, 1792]: - t_np, tot_np, t_torch, tot_torch = measure(size, N=10) - row = [size, f'{t_np.mean():.4f}s ± {t_np.std():.4f}', f'{tot_np:.4f}s', f'{t_torch.mean():.4f}s ± {t_torch.std():.4f}', f'{tot_torch:.4f}s'] + t_np, tot_np, t_torch, tot_torch, t_tf, tot_tf = measure(size, N=10) + # row = [size, f'{t_np.mean():.4f}s ± {t_np.std():.4f}', f'{tot_np:.4f}s', f'{t_torch.mean():.4f}s ± {t_torch.std():.4f}', f'{tot_torch:.4f}s'] + row = [size, f'{t_np.mean():.4f}s ± {t_np.std():.4f}', f'{t_torch.mean():.4f}s ± {t_torch.std():.4f}', f'{t_tf.mean():.4f}s ± {t_tf.std():.4f}'] table.append(row) -print(tabulate(table, headers=['size', 'numpy avg. time', 'numpy tot. time', 'torch avg. time', 'torch tot. time'], tablefmt='github')) +# print(tabulate(table, headers=['size', 'numpy avg. 
time', 'torch tot. time'], tablefmt='github')) +print(tabulate(table, headers=['size', 'numpy avg. time', 'torch avg. time', 'tf avg. time'], tablefmt='github')) + diff --git a/result.png b/data/result.png similarity index 100% rename from result.png rename to data/result.png diff --git a/setup.py b/setup.py index 7eb0b9e..d27acf7 100644 --- a/setup.py +++ b/setup.py @@ -6,8 +6,8 @@ setup( name='torchstain', - version='1.1.0', - description='Pytorch stain normalization utils', + version='1.2.0', + description='Stain normalization tools for histological analysis and computational pathology', long_description=README, long_description_content_type='text/markdown', url='https://github.com/EIDOSlab/torchstain', diff --git a/tests/test_color_conv.py b/tests/test_color_conv.py new file mode 100644 index 0000000..7261307 --- /dev/null +++ b/tests/test_color_conv.py @@ -0,0 +1,15 @@ +from torchstain.numpy.utils.rgb2lab import rgb2lab +from torchstain.numpy.utils.lab2rgb import lab2rgb +import numpy as np +import cv2 +import os + +def test_rgb_to_lab(): + size = 1024 + curr_file_path = os.path.dirname(os.path.realpath(__file__)) + img = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) + + reconstructed_img = lab2rgb(rgb2lab(img)) + val = np.mean(np.abs(reconstructed_img - img)) + print("MAE:", val) + assert val < 0.1 diff --git a/tests/test_normalizers.py b/tests/test_normalizers.py deleted file mode 100644 index 5769b96..0000000 --- a/tests/test_normalizers.py +++ /dev/null @@ -1,45 +0,0 @@ -import os -import cv2 -import torchstain -import torch -from torchvision import transforms -import time -from skimage.metrics import structural_similarity as ssim -import numpy as np - -def test_normalize_all(): - size = 1024 - curr_file_path = os.path.dirname(os.path.realpath(__file__)) - target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, 
size)) - to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) - - # setup preprocessing and preprocess image to be normalized - T = transforms.Compose([ - transforms.ToTensor(), - transforms.Lambda(lambda x: x*255) - ]) - t_to_transform = T(to_transform) - - # initialize normalizers for each backend and fit to target image - normalizer = torchstain.normalizers.MacenkoNormalizer(backend='numpy') - normalizer.fit(target) - - torch_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='torch') - torch_normalizer.fit(T(target)) - - tf_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='tensorflow') - tf_normalizer.fit(T(target)) - - # transform - result_numpy, _, _ = normalizer.normalize(I=to_transform, stains=True) - result_torch, _, _ = torch_normalizer.normalize(I=t_to_transform, stains=True) - result_tf, _, _ = tf_normalizer.normalize(I=t_to_transform, stains=True) - - # convert to numpy and set dtype - result_numpy = result_numpy.astype("float32") - result_torch = result_torch.numpy().astype("float32") - result_tf = result_tf.numpy().astype("float32") - - # assess whether the normalized images are identical across backends - np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True) - np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True) diff --git a/tests/test_tf.py b/tests/test_tf.py new file mode 100644 index 0000000..6098e37 --- /dev/null +++ b/tests/test_tf.py @@ -0,0 +1,79 @@ +import os +import cv2 +import torchstain +import torchstain.tf +import tensorflow as tf +import time +from skimage.metrics import structural_similarity as ssim +import numpy as np + +def test_cov(): + x = np.random.randn(10, 10) + cov_np = np.cov(x) + cov_t = torchstain.tf.utils.cov(x) + + np.testing.assert_almost_equal(cov_np, cov_t.numpy()) + +def test_percentile(): + x 
= np.random.randn(10, 10) + p = 20 + p_np = np.percentile(x, p, interpolation='nearest') + p_t = torchstain.tf.utils.percentile(x, p) + + np.testing.assert_almost_equal(p_np, p_t) + +def test_macenko_tf(): + size = 1024 + curr_file_path = os.path.dirname(os.path.realpath(__file__)) + target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) + to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) + + # setup preprocessing and preprocess image to be normalized + T = lambda x: tf.convert_to_tensor(np.moveaxis(x, -1, 0).astype("float32")) # * 255 + t_to_transform = T(to_transform) + + # initialize normalizers for each backend and fit to target image + normalizer = torchstain.normalizers.MacenkoNormalizer(backend='numpy') + normalizer.fit(target) + + tf_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='tensorflow') + tf_normalizer.fit(T(target)) + + # transform + result_numpy, _, _ = normalizer.normalize(I=to_transform, stains=True) + result_tf, _, _ = tf_normalizer.normalize(I=t_to_transform, stains=True) + + # convert to numpy and set dtype + result_numpy = result_numpy.astype("float32") + result_tf = result_tf.numpy().astype("float32") + + # assess whether the normalized images are identical across backends + np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True) + +def test_reinhard_tf(): + size = 1024 + curr_file_path = os.path.dirname(os.path.realpath(__file__)) + target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) + to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) + + # setup preprocessing and preprocess image to be normalized + T = lambda x: tf.convert_to_tensor(x, 
dtype=tf.float32) + t_to_transform = T(to_transform) + + # initialize normalizers for each backend and fit to target image + normalizer = torchstain.normalizers.ReinhardNormalizer(backend='numpy') + normalizer.fit(target) + + tf_normalizer = torchstain.normalizers.ReinhardNormalizer(backend='tensorflow') + tf_normalizer.fit(T(target)) + + # transform + result_numpy = normalizer.normalize(I=to_transform) + result_tf = tf_normalizer.normalize(I=t_to_transform) + + # convert to numpy and set dtype + result_numpy = result_numpy.astype("float32") + result_tf = result_tf.numpy().astype("float32") + + # assess whether the normalized images are identical across backends + np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True) diff --git a/tests/test_tf_normalizer.py b/tests/test_tf_normalizer.py deleted file mode 100644 index c67d8e9..0000000 --- a/tests/test_tf_normalizer.py +++ /dev/null @@ -1,35 +0,0 @@ -import os -import cv2 -import torchstain -import tensorflow as tf -import time -from skimage.metrics import structural_similarity as ssim -import numpy as np - -def test_normalize_tf(): - size = 1024 - curr_file_path = os.path.dirname(os.path.realpath(__file__)) - target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) - to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) - - # setup preprocessing and preprocess image to be normalized - T = lambda x: tf.convert_to_tensor(np.moveaxis(x, -1, 0).astype("float32")) # * 255 - t_to_transform = T(to_transform) - - # initialize normalizers for each backend and fit to target image - normalizer = torchstain.normalizers.MacenkoNormalizer(backend='numpy') - normalizer.fit(target) - - tf_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='tensorflow') - tf_normalizer.fit(T(target)) - - # transform - 
result_numpy, _, _ = normalizer.normalize(I=to_transform, stains=True) - result_tf, _, _ = tf_normalizer.normalize(I=t_to_transform, stains=True) - - # convert to numpy and set dtype - result_numpy = result_numpy.astype("float32") - result_tf = result_tf.numpy().astype("float32") - - # assess whether the normalized images are identical across backends - np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True) diff --git a/tests/test_torch.py b/tests/test_torch.py new file mode 100644 index 0000000..0a7e99a --- /dev/null +++ b/tests/test_torch.py @@ -0,0 +1,90 @@ +import os +import cv2 +import torchstain +import torchstain.torch +import torch +import torchvision +import time +import numpy as np +from torchvision import transforms +from skimage.metrics import structural_similarity as ssim + +def setup_function(fn): + print("torch version:", torch.__version__, "torchvision version:", torchvision.__version__) + +def test_cov(): + x = np.random.randn(10, 10) + cov_np = np.cov(x) + cov_t = torchstain.torch.utils.cov(torch.tensor(x)) + + np.testing.assert_almost_equal(cov_np, cov_t.numpy()) + +def test_percentile(): + x = np.random.randn(10, 10) + p = 20 + p_np = np.percentile(x, p, interpolation='nearest') + p_t = torchstain.torch.utils.percentile(torch.tensor(x), p) + + np.testing.assert_almost_equal(p_np, p_t) + +def test_macenko_torch(): + size = 1024 + curr_file_path = os.path.dirname(os.path.realpath(__file__)) + target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) + to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) + + # setup preprocessing and preprocess image to be normalized + T = transforms.Compose([ + transforms.ToTensor(), + transforms.Lambda(lambda x: x * 255) + ]) + t_to_transform = T(to_transform) + + # initialize normalizers for each 
backend and fit to target image + normalizer = torchstain.normalizers.MacenkoNormalizer(backend='numpy') + normalizer.fit(target) + + torch_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='torch') + torch_normalizer.fit(T(target)) + + # transform + result_numpy, _, _ = normalizer.normalize(I=to_transform, stains=True) + result_torch, _, _ = torch_normalizer.normalize(I=t_to_transform, stains=True) + + # convert to numpy and set dtype + result_numpy = result_numpy.astype("float32") + result_torch = result_torch.numpy().astype("float32") + + # assess whether the normalized images are identical across backends + np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True) + +def test_reinhard_torch(): + size = 1024 + curr_file_path = os.path.dirname(os.path.realpath(__file__)) + target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) + to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) + + # setup preprocessing and preprocess image to be normalized + T = transforms.Compose([ + transforms.ToTensor(), + transforms.Lambda(lambda x: x * 255) + ]) + t_to_transform = T(to_transform) + + # initialize normalizers for each backend and fit to target image + normalizer = torchstain.normalizers.ReinhardNormalizer(backend='numpy') + normalizer.fit(target) + + torch_normalizer = torchstain.normalizers.ReinhardNormalizer(backend='torch') + torch_normalizer.fit(T(target)) + + # transform + result_numpy = normalizer.normalize(I=to_transform) + result_torch = torch_normalizer.normalize(I=t_to_transform) + + # convert to numpy and set dtype + result_numpy = result_numpy.astype("float32") + result_torch = result_torch.numpy().astype("float32") + + # assess whether the normalized images are identical across backends + 
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True) diff --git a/tests/test_torch_normalizer.py b/tests/test_torch_normalizer.py deleted file mode 100644 index aa493fc..0000000 --- a/tests/test_torch_normalizer.py +++ /dev/null @@ -1,39 +0,0 @@ -import os -import cv2 -import torchstain -import torch -from torchvision import transforms -import time -from skimage.metrics import structural_similarity as ssim -import numpy as np - -def test_normalize_torch(): - size = 1024 - curr_file_path = os.path.dirname(os.path.realpath(__file__)) - target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) - to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) - - # setup preprocessing and preprocess image to be normalized - T = transforms.Compose([ - transforms.ToTensor(), - transforms.Lambda(lambda x: x*255) - ]) - t_to_transform = T(to_transform) - - # initialize normalizers for each backend and fit to target image - normalizer = torchstain.normalizers.MacenkoNormalizer(backend='numpy') - normalizer.fit(target) - - torch_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='torch') - torch_normalizer.fit(T(target)) - - # transform - result_numpy, _, _ = normalizer.normalize(I=to_transform, stains=True) - result_torch, _, _ = torch_normalizer.normalize(I=t_to_transform, stains=True) - - # convert to numpy and set dtype - result_numpy = result_numpy.astype("float32") - result_torch = result_torch.numpy().astype("float32") - - # assess whether the normalized images are identical across backends - np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True) diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 266215e..0000000 --- a/tests/test_utils.py +++ /dev/null @@ 
-1,19 +0,0 @@ -import torch -import torchstain -import torchstain.torch -import numpy as np - -def test_cov(): - x = np.random.randn(10, 10) - cov_np = np.cov(x) - cov_t = torchstain.torch.utils.cov(torch.tensor(x)) - - np.testing.assert_almost_equal(cov_np, cov_t.numpy()) - -def test_percentile(): - x = np.random.randn(10, 10) - p = 20 - p_np = np.percentile(x, p, interpolation='nearest') - p_t = torchstain.torch.utils.percentile(torch.tensor(x), p) - - np.testing.assert_almost_equal(p_np, p_t) diff --git a/torchstain/__init__.py b/torchstain/__init__.py index 980e3a6..dab8618 100644 --- a/torchstain/__init__.py +++ b/torchstain/__init__.py @@ -1,3 +1,3 @@ -__version__ = '1.1.0' +__version__ = '1.2.0' from torchstain.base import normalizers \ No newline at end of file diff --git a/torchstain/base/normalizers/__init__.py b/torchstain/base/normalizers/__init__.py index aceb595..a8f8bb7 100644 --- a/torchstain/base/normalizers/__init__.py +++ b/torchstain/base/normalizers/__init__.py @@ -1,2 +1,3 @@ from .he_normalizer import HENormalizer from .macenko import MacenkoNormalizer +from .reinhard import ReinhardNormalizer diff --git a/torchstain/base/normalizers/reinhard.py b/torchstain/base/normalizers/reinhard.py new file mode 100644 index 0000000..f96f087 --- /dev/null +++ b/torchstain/base/normalizers/reinhard.py @@ -0,0 +1,12 @@ +def ReinhardNormalizer(backend='numpy'): + if backend == 'numpy': + from torchstain.numpy.normalizers import NumpyReinhardNormalizer + return NumpyReinhardNormalizer() + elif backend == "torch": + from torchstain.torch.normalizers import TorchReinhardNormalizer + return TorchReinhardNormalizer() + elif backend == "tensorflow": + from torchstain.tf.normalizers import TensorFlowReinhardNormalizer + return TensorFlowReinhardNormalizer() + else: + raise Exception(f'Unknown backend {backend}') diff --git a/torchstain/numpy/__init__.py b/torchstain/numpy/__init__.py index e69de29..113e6f1 100644 --- a/torchstain/numpy/__init__.py +++ 
b/torchstain/numpy/__init__.py @@ -0,0 +1 @@ +from torchstain.numpy import normalizers, utils diff --git a/torchstain/numpy/normalizers/__init__.py b/torchstain/numpy/normalizers/__init__.py index 2220c20..d453cf1 100644 --- a/torchstain/numpy/normalizers/__init__.py +++ b/torchstain/numpy/normalizers/__init__.py @@ -1 +1,2 @@ -from .macenko import NumpyMacenkoNormalizer \ No newline at end of file +from .macenko import NumpyMacenkoNormalizer +from .reinhard import NumpyReinhardNormalizer \ No newline at end of file diff --git a/torchstain/numpy/normalizers/reinhard.py b/torchstain/numpy/normalizers/reinhard.py new file mode 100644 index 0000000..fbb8d13 --- /dev/null +++ b/torchstain/numpy/normalizers/reinhard.py @@ -0,0 +1,55 @@ +import numpy as np +from torchstain.base.normalizers import HENormalizer +from torchstain.numpy.utils.rgb2lab import rgb2lab +from torchstain.numpy.utils.lab2rgb import lab2rgb +from torchstain.numpy.utils.split import csplit, cmerge, lab_split, lab_merge +from torchstain.numpy.utils.stats import get_mean_std, standardize + +""" +Source code adapted from: +https://github.com/DigitalSlideArchive/HistomicsTK/blob/master/histomicstk/preprocessing/color_normalization/reinhard.py +https://github.com/Peter554/StainTools/blob/master/staintools/reinhard_color_normalizer.py +""" +class NumpyReinhardNormalizer(HENormalizer): + def __init__(self): + super().__init__() + self.target_mus = None + self.target_stds = None + + def fit(self, target): + # normalize + target = target.astype("float32") / 255 + + # convert to LAB + lab = rgb2lab(target) + + # get summary statistics + stack_ = np.array([get_mean_std(x) for x in lab_split(lab)]) + self.target_means = stack_[:, 0] + self.target_stds = stack_[:, 1] + + def normalize(self, I): + # normalize + I = I.astype("float32") / 255 + + # convert to LAB + lab = rgb2lab(I) + labs = lab_split(lab) + + # get summary statistics from LAB + stack_ = np.array([get_mean_std(x) for x in labs]) + mus = stack_[:, 0] + 
stds = stack_[:, 1] + + # standardize intensities channel-wise and normalize using target mus and stds + result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \ + in zip(labs, mus, stds, self.target_means, self.target_stds)] + + # rebuild LAB + lab = lab_merge(*result) + + # convert back to RGB from LAB + lab = lab2rgb(lab) + + # rescale to [0, 255] uint8 + return (lab * 255).astype("uint8") diff --git a/torchstain/numpy/utils/__init__.py b/torchstain/numpy/utils/__init__.py new file mode 100644 index 0000000..f440077 --- /dev/null +++ b/torchstain/numpy/utils/__init__.py @@ -0,0 +1,4 @@ +from torchstain.numpy.utils.rgb2lab import * +from torchstain.numpy.utils.lab2rgb import * +from torchstain.numpy.utils.split import * +from torchstain.numpy.utils.stats import * diff --git a/torchstain/numpy/utils/lab2rgb.py b/torchstain/numpy/utils/lab2rgb.py new file mode 100644 index 0000000..ddcd6c0 --- /dev/null +++ b/torchstain/numpy/utils/lab2rgb.py @@ -0,0 +1,38 @@ +import numpy as np +from torchstain.numpy.utils.rgb2lab import _rgb2xyz + +_xyz2rgb = np.linalg.inv(_rgb2xyz) + +""" +Implementation is based on: +https://github.com/scikit-image/scikit-image/blob/00177e14097237ef20ed3141ed454bc81b308f82/skimage/color/colorconv.py#L704 +""" +def lab2rgb(lab): + lab = lab.astype("float32") + # first rescale back from OpenCV format + lab[..., 0] /= 2.55 + lab[..., 1] -= 128 + lab[..., 2] -= 128 + + # convert LAB -> XYZ color domain + L, a, b = lab[..., 0], lab[..., 1], lab[..., 2] + y = (L + 16.) / 116. + x = (a / 500.) + y + z = y - (b / 200.) + + out = np.stack([x, y, z], axis=-1) + + mask = out > 0.2068966 + out[mask] = np.power(out[mask], 3.) + out[~mask] = (out[~mask] - 16.0 / 116.) 
/ 7.787 + + # rescale to the reference white (illuminant) + out *= np.array((0.95047, 1., 1.08883), dtype=out.dtype) + + # convert XYZ -> RGB color domain + arr = out.copy() + arr = np.dot(arr, _xyz2rgb.T) + mask = arr > 0.0031308 + arr[mask] = 1.055 * np.power(arr[mask], 1 / 2.4) - 0.055 + arr[~mask] *= 12.92 + return np.clip(arr, 0, 1) diff --git a/torchstain/numpy/utils/rgb2lab.py b/torchstain/numpy/utils/rgb2lab.py new file mode 100644 index 0000000..e0edbed --- /dev/null +++ b/torchstain/numpy/utils/rgb2lab.py @@ -0,0 +1,45 @@ +import numpy as np + +# constant conversion matrices between color spaces: https://gist.github.com/bikz05/6fd21c812ef6ebac66e1 +_rgb2xyz = np.array([[0.412453, 0.357580, 0.180423], + [0.212671, 0.715160, 0.072169], + [0.019334, 0.119193, 0.950227]]) + +""" +Implementation adapted from: +https://gist.github.com/bikz05/6fd21c812ef6ebac66e1 +https://github.com/scikit-image/scikit-image/blob/00177e14097237ef20ed3141ed454bc81b308f82/skimage/color/colorconv.py#L704 +""" +def rgb2lab(rgb): + rgb = rgb.astype("float32") + + # convert rgb -> xyz color domain + arr = rgb.copy() + mask = arr > 0.04045 + arr[mask] = np.power((arr[mask] + 0.055) / 1.055, 2.4) + arr[~mask] /= 12.92 + xyz = np.dot(arr, _rgb2xyz.T.astype(arr.dtype)) + + # scale by CIE XYZ tristimulus values of the reference white point + arr = xyz.copy() + arr = arr / np.asarray((0.95047, 1., 1.08883), dtype=xyz.dtype) + + # Nonlinear distortion and linear transformation + mask = arr > 0.008856 + arr[mask] = np.cbrt(arr[mask]) + arr[~mask] = 7.787 * arr[~mask] + 16. / 116. + + x, y, z = arr[..., 0], arr[..., 1], arr[..., 2] + + # Vector scaling + L = (116. * y) - 16. 
+ a = 500.0 * (x - y) + b = 200.0 * (y - z) + + # OpenCV format + L *= 2.55 + a += 128 + b += 128 + + # finally, get LAB color domain + return np.concatenate([x[..., np.newaxis] for x in [L, a, b]], axis=-1) diff --git a/torchstain/numpy/utils/split.py b/torchstain/numpy/utils/split.py new file mode 100644 index 0000000..978e9a2 --- /dev/null +++ b/torchstain/numpy/utils/split.py @@ -0,0 +1,16 @@ +import numpy as np +from torchstain.numpy.utils.rgb2lab import rgb2lab + +def csplit(I): + return [I[..., i] for i in range(I.shape[-1])] + +def cmerge(I1, I2, I3): + return np.stack([I1, I2, I3], axis=-1) + +def lab_split(I): + I = I.astype("float32") + I1, I2, I3 = csplit(I) + return I1 / 2.55, I2 - 128, I3 - 128 + +def lab_merge(I1, I2, I3): + return cmerge(I1 * 2.55, I2 + 128, I3 + 128) diff --git a/torchstain/numpy/utils/stats.py b/torchstain/numpy/utils/stats.py new file mode 100644 index 0000000..b43ebe1 --- /dev/null +++ b/torchstain/numpy/utils/stats.py @@ -0,0 +1,7 @@ +import numpy as np + +def get_mean_std(I): + return np.mean(I), np.std(I) + +def standardize(x, mu, std): + return (x - mu) / std diff --git a/torchstain/tf/normalizers/__init__.py b/torchstain/tf/normalizers/__init__.py index f835e4d..fb0718f 100644 --- a/torchstain/tf/normalizers/__init__.py +++ b/torchstain/tf/normalizers/__init__.py @@ -1 +1,2 @@ -from torchstain.tf.normalizers.macenko import TensorFlowMacenkoNormalizer \ No newline at end of file +from torchstain.tf.normalizers.macenko import TensorFlowMacenkoNormalizer +from torchstain.tf.normalizers.reinhard import TensorFlowReinhardNormalizer diff --git a/torchstain/tf/normalizers/macenko.py b/torchstain/tf/normalizers/macenko.py index e9a69ce..bf21666 100644 --- a/torchstain/tf/normalizers/macenko.py +++ b/torchstain/tf/normalizers/macenko.py @@ -1,6 +1,6 @@ import tensorflow as tf from torchstain.base.normalizers.he_normalizer import HENormalizer -from torchstain.tf.utils import cov_tf, percentile_tf, solveLS +from torchstain.tf.utils 
import cov, percentile, solveLS import numpy as np import tensorflow.keras.backend as K @@ -35,8 +35,8 @@ def __find_HE(self, ODhat, eigvecs, alpha): That = tf.linalg.matmul(ODhat, eigvecs) phi = tf.math.atan2(That[:, 1], That[:, 0]) - minPhi = percentile_tf(phi, alpha) - maxPhi = percentile_tf(phi, 100 - alpha) + minPhi = percentile(phi, alpha) + maxPhi = percentile(phi, 100 - alpha) vMin = tf.matmul(eigvecs, tf.expand_dims(tf.stack((tf.math.cos(minPhi), tf.math.sin(minPhi))), axis=-1)) vMax = tf.matmul(eigvecs, tf.expand_dims(tf.stack((tf.math.cos(maxPhi), tf.math.sin(maxPhi))), axis=-1)) @@ -58,13 +58,13 @@ def __compute_matrices(self, I, Io, alpha, beta): OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) # compute eigenvectors - _, eigvecs = tf.linalg.eigh(cov_tf(tf.transpose(ODhat))) + _, eigvecs = tf.linalg.eigh(cov(tf.transpose(ODhat))) eigvecs = eigvecs[:, 1:3] HE = self.__find_HE(ODhat, eigvecs, alpha) C = self.__find_concentration(OD, HE) - maxC = tf.stack([percentile_tf(C[0, :], 99), percentile_tf(C[1, :], 99)]) + maxC = tf.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)]) return HE, C, maxC diff --git a/torchstain/tf/normalizers/reinhard.py b/torchstain/tf/normalizers/reinhard.py new file mode 100644 index 0000000..8a2b601 --- /dev/null +++ b/torchstain/tf/normalizers/reinhard.py @@ -0,0 +1,55 @@ +import tensorflow as tf +from torchstain.base.normalizers import HENormalizer +from torchstain.tf.utils.rgb2lab import rgb2lab +from torchstain.tf.utils.lab2rgb import lab2rgb +from torchstain.tf.utils.split import csplit, cmerge, lab_split, lab_merge +from torchstain.tf.utils.stats import get_mean_std, standardize + +""" +Source code adapted from: +https://github.com/DigitalSlideArchive/HistomicsTK/blob/master/histomicstk/preprocessing/color_normalization/reinhard.py +https://github.com/Peter554/StainTools/blob/master/staintools/reinhard_color_normalizer.py +""" +class TensorFlowReinhardNormalizer(HENormalizer): + def __init__(self): + 
super().__init__() + self.target_means = None + self.target_stds = None + + def fit(self, target): + # normalize + target = tf.cast(target, tf.float32) / 255 + + # convert to LAB + lab = rgb2lab(target) + + # get summary statistics + stack_ = tf.convert_to_tensor([get_mean_std(x) for x in lab_split(lab)]) + self.target_means = stack_[:, 0] + self.target_stds = stack_[:, 1] + + def normalize(self, I): + # normalize + I = tf.cast(I, tf.float32) / 255 + + # convert to LAB + lab = rgb2lab(I) + labs = lab_split(lab) + + # get summary statistics from LAB + stack_ = tf.convert_to_tensor([get_mean_std(x) for x in labs]) + mus = stack_[:, 0] + stds = stack_[:, 1] + + # standardize intensities channel-wise and normalize using target mus and stds + result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \ + in zip(labs, mus, stds, self.target_means, self.target_stds)] + + # rebuild LAB + lab = lab_merge(*result) + + # convert back to RGB from LAB + lab = lab2rgb(lab) + + # rescale to [0, 255] uint8 + return tf.cast(lab * 255, tf.uint8) diff --git a/torchstain/tf/utils/__init__.py b/torchstain/tf/utils/__init__.py index 1c65f31..3e2035c 100644 --- a/torchstain/tf/utils/__init__.py +++ b/torchstain/tf/utils/__init__.py @@ -1,3 +1,7 @@ -from torchstain.tf.utils.cov import cov_tf -from torchstain.tf.utils.percentile import percentile_tf -from torchstain.tf.normalizers.solveLS import solveLS +from torchstain.tf.utils.cov import cov +from torchstain.tf.utils.percentile import percentile +from torchstain.tf.utils.solveLS import solveLS +from torchstain.tf.utils.stats import * +from torchstain.tf.utils.split import * +from torchstain.tf.utils.rgb2lab import * +from torchstain.tf.utils.lab2rgb import * diff --git a/torchstain/tf/utils/cov.py b/torchstain/tf/utils/cov.py index c0ce049..3429d90 100644 --- a/torchstain/tf/utils/cov.py +++ b/torchstain/tf/utils/cov.py @@ -1,6 +1,6 @@ import tensorflow as tf -def cov_tf(x): +def cov(x): """
https://en.wikipedia.org/wiki/Covariance_matrix """ diff --git a/torchstain/tf/utils/lab2rgb.py b/torchstain/tf/utils/lab2rgb.py new file mode 100644 index 0000000..b1d2638 --- /dev/null +++ b/torchstain/tf/utils/lab2rgb.py @@ -0,0 +1,36 @@ +import tensorflow as tf +from torchstain.tf.utils.rgb2lab import _rgb2xyz, _white + +_xyz2rgb = tf.linalg.inv(_rgb2xyz) + +def lab2rgb(lab): + lab = tf.cast(lab, tf.float32) + + # rescale back from OpenCV format and extract LAB channel + L, a, b = lab[..., 0] / 2.55, lab[..., 1] - 128, lab[..., 2] - 128 + + # vector scaling to produce X, Y, Z + y = (L + 16.) / 116. + x = (a / 500.) + y + z = y - (b / 200.) + + # merge back to get reconstructed XYZ color image + out = tf.stack([x, y, z], axis=-1) + + # apply boolean transforms + mask = out > 0.2068966 + not_mask = tf.math.logical_not(mask) + out = tf.tensor_scatter_nd_update(out, tf.where(mask), tf.pow(tf.boolean_mask(out, mask), 3)) + out = tf.tensor_scatter_nd_update(out, tf.where(not_mask), (tf.boolean_mask(out, not_mask) - 16 / 116) / 7.787) + + # rescale to the reference white (illuminant) + out = out * tf.cast(_white, out.dtype) + + # convert XYZ -> RGB color domain + arr = tf.identity(out) + arr = arr @ tf.transpose(_xyz2rgb) + mask = arr > 0.0031308 + not_mask = tf.math.logical_not(mask) + arr = tf.tensor_scatter_nd_update(arr, tf.where(mask), 1.055 * tf.pow(tf.boolean_mask(arr, mask), 1 / 2.4) - 0.055) + arr = tf.tensor_scatter_nd_update(arr, tf.where(not_mask), tf.boolean_mask(arr, not_mask) * 12.92) + return tf.clip_by_value(arr, 0, 1) diff --git a/torchstain/tf/utils/percentile.py b/torchstain/tf/utils/percentile.py index c6697dc..32f7c67 100644 --- a/torchstain/tf/utils/percentile.py +++ b/torchstain/tf/utils/percentile.py @@ -1,12 +1,11 @@ from typing import Union import tensorflow as tf -def percentile_tf(t: tf.Tensor, q: float) -> Union[int, float]: +def percentile(t: tf.Tensor, q: float) -> Union[int, float]: """ Return the ``q``-th percentile of the flattened
input tensor's data. CAUTION: - * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used. * Values are not interpolated, which corresponds to ``numpy.percentile(..., interpolation="nearest")``. @@ -14,5 +13,5 @@ def percentile_tf(t: tf.Tensor, q: float) -> Union[int, float]: :param q: Percentile to compute, which must be between 0 and 100 inclusive. :return: Resulting value (scalar). """ - k = 1 + tf.math.round(.01 * tf.cast(q, tf.float32) * (tf.cast(tf.size(t), tf.float32) - 1)) - return tf.sort(tf.reshape(t, [-1]))[tf.cast(k, tf.int32)] + k = 1 + tf.math.round(.01 * tf.cast(q, tf.float32) * (tf.cast(tf.math.reduce_prod(tf.size(t)), tf.float32) - 1)) + return tf.sort(tf.reshape(t, [-1]))[tf.cast(k - 1, tf.int32)] diff --git a/torchstain/tf/utils/rgb2lab.py b/torchstain/tf/utils/rgb2lab.py new file mode 100644 index 0000000..d679a63 --- /dev/null +++ b/torchstain/tf/utils/rgb2lab.py @@ -0,0 +1,45 @@ +import tensorflow as tf + +# constant conversion matrices between color spaces: https://gist.github.com/bikz05/6fd21c812ef6ebac66e1 +_rgb2xyz = tf.constant([[0.412453, 0.357580, 0.180423], + [0.212671, 0.715160, 0.072169], + [0.019334, 0.119193, 0.950227]]) + +_white = tf.constant([0.95047, 1., 1.08883]) + +def rgb2lab(rgb): + arr = tf.cast(rgb, tf.float32) + + # convert rgb -> xyz color domain + mask = arr > 0.04045 + not_mask = tf.math.logical_not(mask) + arr = tf.tensor_scatter_nd_update(arr, tf.where(mask), tf.math.pow((tf.boolean_mask(arr, mask) + 0.055) / 1.055, 2.4)) + arr = tf.tensor_scatter_nd_update(arr, tf.where(not_mask), tf.boolean_mask(arr, not_mask) / 12.92) + + xyz = arr @ tf.cast(tf.transpose(_rgb2xyz), arr.dtype) + + # scale by CIE XYZ tristimulus values of the reference white point + arr = tf.identity(xyz) + arr = arr / tf.cast(_white, xyz.dtype) + + # nonlinear distortion and linear transformation + mask = arr > 0.008856 + not_mask = tf.math.logical_not(mask) + arr = tf.tensor_scatter_nd_update(arr, tf.where(mask), tf.math.pow(tf.boolean_mask(arr, 
mask), 1.0 / 3.0)) + arr = tf.tensor_scatter_nd_update(arr, tf.where(not_mask), 7.787 * tf.boolean_mask(arr, not_mask) + 16 / 116) + + # get each channel as individual tensors + x, y, z = arr[..., 0], arr[..., 1], arr[..., 2] + + # vector scaling + L = (116. * y) - 16. + a = 500.0 * (x - y) + b = 200.0 * (y - z) + + # OpenCV format + L *= 2.55 + a += 128 + b += 128 + + # finally, get LAB color domain + return tf.stack([L, a, b], axis=-1) diff --git a/torchstain/tf/normalizers/solveLS.py b/torchstain/tf/utils/solveLS.py similarity index 100% rename from torchstain/tf/normalizers/solveLS.py rename to torchstain/tf/utils/solveLS.py diff --git a/torchstain/tf/utils/split.py b/torchstain/tf/utils/split.py new file mode 100644 index 0000000..5717b87 --- /dev/null +++ b/torchstain/tf/utils/split.py @@ -0,0 +1,15 @@ +import tensorflow as tf + +def csplit(I): + return [I[..., i] for i in range(I.shape[-1])] + +def cmerge(I1, I2, I3): + return tf.stack([I1, I2, I3], axis=-1) + +def lab_split(I): + I = tf.cast(I, tf.float32) + I1, I2, I3 = csplit(I) + return I1 / 2.55, I2 - 128, I3 - 128 + +def lab_merge(I1, I2, I3): + return cmerge(I1 * 2.55, I2 + 128, I3 + 128) diff --git a/torchstain/tf/utils/stats.py b/torchstain/tf/utils/stats.py new file mode 100644 index 0000000..962afd5 --- /dev/null +++ b/torchstain/tf/utils/stats.py @@ -0,0 +1,7 @@ +import tensorflow as tf + +def get_mean_std(I): + return tf.math.reduce_mean(I), tf.math.reduce_std(I) + +def standardize(x, mu, std): + return (x - mu) / std diff --git a/torchstain/torch/normalizers/__init__.py b/torchstain/torch/normalizers/__init__.py index febcd90..c78c273 100644 --- a/torchstain/torch/normalizers/__init__.py +++ b/torchstain/torch/normalizers/__init__.py @@ -1 +1,2 @@ from torchstain.torch.normalizers.macenko import TorchMacenkoNormalizer +from torchstain.torch.normalizers.reinhard import TorchReinhardNormalizer diff --git a/torchstain/torch/normalizers/reinhard.py b/torchstain/torch/normalizers/reinhard.py new 
file mode 100644 index 0000000..0970764 --- /dev/null +++ b/torchstain/torch/normalizers/reinhard.py @@ -0,0 +1,55 @@ +import torch +from torchstain.base.normalizers import HENormalizer +from torchstain.torch.utils.rgb2lab import rgb2lab +from torchstain.torch.utils.lab2rgb import lab2rgb +from torchstain.torch.utils.split import csplit, cmerge, lab_split, lab_merge +from torchstain.torch.utils.stats import get_mean_std, standardize + +""" +Source code adapted from: +https://github.com/DigitalSlideArchive/HistomicsTK/blob/master/histomicstk/preprocessing/color_normalization/reinhard.py +https://github.com/Peter554/StainTools/blob/master/staintools/reinhard_color_normalizer.py +""" +class TorchReinhardNormalizer(HENormalizer): + def __init__(self): + super().__init__() + self.target_means = None + self.target_stds = None + + def fit(self, target): + # normalize + target = target.type(torch.float32) / 255 + + # convert to LAB + lab = rgb2lab(target) + + # get summary statistics + stack_ = torch.tensor([get_mean_std(x) for x in lab_split(lab)]) + self.target_means = stack_[:, 0] + self.target_stds = stack_[:, 1] + + def normalize(self, I): + # normalize + I = I.type(torch.float32) / 255 + + # convert to LAB + lab = rgb2lab(I) + labs = lab_split(lab) + + # get summary statistics from LAB + stack_ = torch.tensor([get_mean_std(x) for x in labs]) + mus = stack_[:, 0] + stds = stack_[:, 1] + + # standardize intensities channel-wise and normalize using target mus and stds + result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \ + in zip(labs, mus, stds, self.target_means, self.target_stds)] + + # rebuild LAB + lab = lab_merge(*result) + + # convert back to RGB from LAB + lab = lab2rgb(lab) + + # rescale to [0, 255] uint8 + return (lab * 255).type(torch.uint8) diff --git a/torchstain/torch/utils/__init__.py b/torchstain/torch/utils/__init__.py index 5e0de3e..4acea5a 100644 ---
b/torchstain/torch/utils/__init__.py @@ -1,2 +1,6 @@ from torchstain.torch.utils.cov import cov from torchstain.torch.utils.percentile import percentile +from torchstain.torch.utils.stats import * +from torchstain.torch.utils.split import * +from torchstain.torch.utils.rgb2lab import * +from torchstain.torch.utils.lab2rgb import * diff --git a/torchstain/torch/utils/lab2rgb.py b/torchstain/torch/utils/lab2rgb.py new file mode 100644 index 0000000..250c1c9 --- /dev/null +++ b/torchstain/torch/utils/lab2rgb.py @@ -0,0 +1,35 @@ +import torch +from torchstain.torch.utils.rgb2lab import _rgb2xyz, _white + +_xyz2rgb = torch.linalg.inv(_rgb2xyz) + +def lab2rgb(lab): + lab = lab.type(torch.float32) + + # rescale back from OpenCV format and extract LAB channel + L, a, b = lab[0] / 2.55, lab[1] - 128, lab[2] - 128 + + # vector scaling to produce X, Y, Z + y = (L + 16.) / 116. + x = (a / 500.) + y + z = y - (b / 200.) + + # merge back to get reconstructed XYZ color image + out = torch.stack([x, y, z], axis=0) + + # apply boolean transforms + mask = out > 0.2068966 + not_mask = torch.logical_not(mask) + out.masked_scatter_(mask, torch.pow(torch.masked_select(out, mask), 3)) + out.masked_scatter_(not_mask, (torch.masked_select(out, not_mask) - 16 / 116) / 7.787) + + # rescale to the reference white (illuminant) + out = torch.mul(out, _white.type(out.dtype).unsqueeze(dim=-1).unsqueeze(dim=-1)) + + # convert XYZ -> RGB color domain + arr = torch.tensordot(out, torch.t(_xyz2rgb).type(out.dtype), dims=([0], [0])) + mask = arr > 0.0031308 + not_mask = torch.logical_not(mask) + arr.masked_scatter_(mask, 1.055 * torch.pow(torch.masked_select(arr, mask), 1 / 2.4) - 0.055) + arr.masked_scatter_(not_mask, torch.masked_select(arr, not_mask) * 12.92) + return torch.clamp(arr, 0, 1) diff --git a/torchstain/torch/utils/percentile.py b/torchstain/torch/utils/percentile.py index 188f1bb..08c28ef 100644 --- a/torchstain/torch/utils/percentile.py +++ b/torchstain/torch/utils/percentile.py @@ 
-4,9 +4,9 @@ """ Author: https://gist.github.com/spezold/42a451682422beb42bc43ad0c0967a30 """ -def percentile(t: torch.tensor, q: float) -> Union[int, float]: +def percentile(t: torch.Tensor, q: float) -> Union[int, float]: """ - Return the ``q``-th percentile of the flattened input tensor's data. + Return the ``q``-th percentile of the flattened input tensor's data. CAUTION: * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used. diff --git a/torchstain/torch/utils/rgb2lab.py b/torchstain/torch/utils/rgb2lab.py new file mode 100644 index 0000000..3b1aa50 --- /dev/null +++ b/torchstain/torch/utils/rgb2lab.py @@ -0,0 +1,44 @@ +import torch + +# constant conversion matrices between color spaces: https://gist.github.com/bikz05/6fd21c812ef6ebac66e1 +_rgb2xyz = torch.tensor([[0.412453, 0.357580, 0.180423], + [0.212671, 0.715160, 0.072169], + [0.019334, 0.119193, 0.950227]]) + +_white = torch.tensor([0.95047, 1., 1.08883]) + +def rgb2lab(rgb): + arr = rgb.type(torch.float32) + + # convert rgb -> xyz color domain + mask = arr > 0.04045 + not_mask = torch.logical_not(mask) + arr.masked_scatter_(mask, torch.pow((torch.masked_select(arr, mask) + 0.055) / 1.055, 2.4)) + arr.masked_scatter_(not_mask, torch.masked_select(arr, not_mask) / 12.92) + + xyz = torch.tensordot(torch.t(_rgb2xyz), arr, dims=([0], [0])) + + # scale by CIE XYZ tristimulus values of the reference white point + arr = torch.mul(xyz, 1 / _white.type(xyz.dtype).unsqueeze(dim=-1).unsqueeze(dim=-1)) + + # nonlinear distortion and linear transformation + mask = arr > 0.008856 + not_mask = torch.logical_not(mask) + arr.masked_scatter_(mask, torch.pow(torch.masked_select(arr, mask), 1 / 3)) + arr.masked_scatter_(not_mask, 7.787 * torch.masked_select(arr, not_mask) + 16 / 116) + + # get each channel as individual tensors + x, y, z = arr[0], arr[1], arr[2] + + # vector scaling + L = (116. * y) - 16.
+ a = 500.0 * (x - y) + b = 200.0 * (y - z) + + # OpenCV format + L *= 2.55 + a += 128 + b += 128 + + # finally, get LAB color domain + return torch.stack([L, a, b], axis=0) diff --git a/torchstain/torch/utils/split.py b/torchstain/torch/utils/split.py new file mode 100644 index 0000000..d6f6fdb --- /dev/null +++ b/torchstain/torch/utils/split.py @@ -0,0 +1,15 @@ +import torch + +def csplit(I): + return [I[i] for i in range(I.shape[0])] + +def cmerge(I1, I2, I3): + return torch.stack([I1, I2, I3], dim=0) + +def lab_split(I): + I = I.type(torch.float32) + I1, I2, I3 = csplit(I) + return I1 / 2.55, I2 - 128, I3 - 128 + +def lab_merge(I1, I2, I3): + return cmerge(I1 * 2.55, I2 + 128, I3 + 128) diff --git a/torchstain/torch/utils/stats.py b/torchstain/torch/utils/stats.py new file mode 100644 index 0000000..0fa45bb --- /dev/null +++ b/torchstain/torch/utils/stats.py @@ -0,0 +1,7 @@ +import torch + +def get_mean_std(I): + return torch.mean(I), torch.std(I) + +def standardize(x, mu, std): + return (x - mu) / std