Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
jobs:
build_wheels:
name: Build release
runs-on: ubuntu-18.04
runs-on: ubuntu-20.04

steps:
- uses: actions/checkout@v3
Expand All @@ -27,7 +27,7 @@ jobs:

upload_pypi:
needs: build_wheels
runs-on: ubuntu-18.04
runs-on: ubuntu-20.04

if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')

Expand Down
8 changes: 5 additions & 3 deletions .github/workflows/tests_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ on:

jobs:
build:
runs-on: ubuntu-18.04
runs-on: ubuntu-20.04

if: startsWith(github.ref, 'refs/tags/v') != true

Expand Down Expand Up @@ -39,7 +39,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ windows-2019, ubuntu-18.04, macos-11 ]
os: [ windows-2019, ubuntu-20.04, macos-11 ]
python-version: [ 3.7, 3.8, 3.9 ]
tf-version: [2.7.0, 2.8.0, 2.9.0]

Expand Down Expand Up @@ -71,14 +71,16 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ windows-2019, ubuntu-18.04, macos-11 ]
os: [ windows-2019, ubuntu-20.04, macos-11 ]
python-version: [ 3.6, 3.7, 3.8, 3.9 ]
pytorch-version: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0]
exclude:
- python-version: 3.6
pytorch-version: 1.11.0
- python-version: 3.6
pytorch-version: 1.12.0
- python-version: 3.6
pytorch-version: 1.13.0

steps:
- uses: actions/checkout@v1
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests_quick.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ on:

jobs:
build:
runs-on: ubuntu-18.04
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.6
Expand All @@ -33,7 +33,7 @@ jobs:

test-tf:
needs: build
runs-on: ubuntu-18.04
runs-on: ubuntu-20.04

steps:
- uses: actions/checkout@v1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name='torchstain',
version='1.2.0',
version='1.3.0',
description='Stain normalization tools for histological analysis and computational pathology',
long_description=README,
long_description_content_type='text/markdown',
Expand Down
12 changes: 8 additions & 4 deletions tests/test_color_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
import cv2
import os

def test_rgb_to_lab():
def test_rgb_lab():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
img = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))

# rgb2lab expects data to be float32 in range [0, 1]
img = img / 255

# convert from RGB to LAB and back again to RGB
reconstructed_img = lab2rgb(rgb2lab(img))
val = np.mean(np.abs(reconstructed_img - img))
print("MAE:", val)
assert val < 0.1

# assess if the reconstructed image is similar to the original image
np.testing.assert_almost_equal(np.mean(np.abs(reconstructed_img - img)), 0.0, decimal=4, verbose=True)
14 changes: 6 additions & 8 deletions tests/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import torchstain
import torchstain.tf
import tensorflow as tf
import time
from skimage.metrics import structural_similarity as ssim
import numpy as np

def test_cov():
Expand Down Expand Up @@ -44,11 +42,11 @@ def test_macenko_tf():
result_tf, _, _ = tf_normalizer.normalize(I=t_to_transform, stains=True)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32")
result_tf = result_tf.numpy().astype("float32")
result_numpy = result_numpy.astype("float32") / 255.
result_tf = result_tf.numpy().astype("float32") / 255.

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True)
np.testing.assert_almost_equal(result_numpy.flatten(), result_tf.flatten(), decimal=2, verbose=True)

def test_reinhard_tf():
size = 1024
Expand All @@ -72,8 +70,8 @@ def test_reinhard_tf():
result_tf = tf_normalizer.normalize(I=t_to_transform)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32")
result_tf = result_tf.numpy().astype("float32")
result_numpy = result_numpy.astype("float32") / 255.
result_tf = result_tf.numpy().astype("float32") / 255.

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True)
np.testing.assert_almost_equal(result_numpy.flatten(), result_tf.flatten(), decimal=2, verbose=True)
16 changes: 8 additions & 8 deletions tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import torchstain.torch
import torch
import torchvision
import time
import numpy as np
from torchvision import transforms
from skimage.metrics import structural_similarity as ssim


def setup_function(fn):
print("torch version:", torch.__version__, "torchvision version:", torchvision.__version__)
Expand Down Expand Up @@ -52,11 +51,11 @@ def test_macenko_torch():
result_torch, _, _ = torch_normalizer.normalize(I=t_to_transform, stains=True)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32")
result_torch = result_torch.numpy().astype("float32")
result_numpy = result_numpy.astype("float32") / 255.
result_torch = result_torch.numpy().astype("float32") / 255.

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True)
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)

def test_reinhard_torch():
size = 1024
Expand All @@ -83,8 +82,9 @@ def test_reinhard_torch():
result_torch = torch_normalizer.normalize(I=t_to_transform)

# convert to numpy and set dtype
result_numpy = result_numpy.astype("float32")
result_torch = result_torch.numpy().astype("float32")
result_numpy = result_numpy.astype("float32") / 255.
result_torch = result_torch.numpy().astype("float32") / 255.


# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True)
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)
2 changes: 1 addition & 1 deletion torchstain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = '1.2.0'
__version__ = '1.3.0'

from torchstain.base import normalizers
47 changes: 27 additions & 20 deletions torchstain/numpy/utils/lab2rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,39 @@
Implementation is based on:
https://github.com/scikit-image/scikit-image/blob/00177e14097237ef20ed3141ed454bc81b308f82/skimage/color/colorconv.py#L704
"""
def lab2rgb(lab):
lab = lab.astype("float32")
def lab2rgb(lab: np.ndarray) -> np.ndarray:
"""
Convert an array of LAB values to RGB values.

Args:
lab (np.ndarray): An array of shape (..., 3) containing LAB values.

Returns:
np.ndarray: An array of shape (..., 3) containing RGB values.
"""
# first rescale back from OpenCV format
lab[..., 0] /= 2.55
lab[..., 1] -= 128
lab[..., 2] -= 128
lab[..., 1:] -= 128

# convert LAB -> XYZ color domain
L, a, b = lab[..., 0], lab[..., 1], lab[..., 2]
y = (L + 16.) / 116.
x = (a / 500.) + y
z = y - (b / 200.)
y = (lab[..., 0] + 16.) / 116.
x = (lab[..., 1] / 500.) + y
z = y - (lab[..., 2] / 200.)

out = np.stack([x, y, z], axis=-1)
xyz = np.stack([x, y, z], axis=-1)

mask = out > 0.2068966
out[mask] = np.power(out[mask], 3.)
out[~mask] = (out[~mask] - 16.0 / 116.) / 7.787
mask = xyz > 0.2068966
xyz[mask] = np.power(xyz[mask], 3.)
xyz[~mask] = (xyz[~mask] - 16.0 / 116.) / 7.787

# rescale to the reference white (illuminant)
out *= np.array((0.95047, 1., 1.08883), dtype=out.dtype)
xyz *= np.array((0.95047, 1., 1.08883), dtype=xyz.dtype)

# convert XYZ -> RGB color domain
arr = out.copy()
arr = np.dot(arr, _xyz2rgb.T)
mask = arr > 0.0031308
arr[mask] = 1.055 * np.power(arr[mask], 1 / 2.4) - 0.055
arr[~mask] *= 12.92
return np.clip(arr, 0, 1)
rgb = np.matmul(xyz, _xyz2rgb.T)

mask = rgb > 0.0031308
rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055
rgb[~mask] *= 12.92

return np.clip(rgb, 0, 1)
6 changes: 4 additions & 2 deletions torchstain/torch/normalizers/macenko.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def __init__(self):
[0.7201, 0.8012],
[0.4062, 0.5581]])
self.maxCRef = torch.tensor([1.9705, 1.0308])
self.deprecated_torch = torch.__version__ < (1,9,0)

# Avoid using deprecated torch.lstsq (since 1.9.0)
self.updated_lstsq = hasattr(torch.linalg, 'lstsq')

def __convert_rgb2od(self, I, Io, beta):
I = I.permute(1, 2, 0)
Expand Down Expand Up @@ -50,7 +52,7 @@ def __find_concentration(self, OD, HE):
Y = OD.T

# determine concentrations of the individual stains
if self.deprecated_torch:
if not self.updated_lstsq:
return torch.lstsq(Y, HE)[0][:2]

return torch.linalg.lstsq(HE, Y)[0]
Expand Down