Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion napari_cellseg3d/_tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __call__(self, x):
post_process_transforms=mock_work(),
)
assert isinstance(res, InferenceResult)
assert res.result is not None
assert res.semantic_segmentation is not None


def test_post_processing():
Expand Down
10 changes: 5 additions & 5 deletions napari_cellseg3d/_tests/test_plugin_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,23 @@ def test_inference(make_napari_viewer_proxy, qtbot):

assert len(viewer.layers) == 1

widget.window_infer_box.setChecked(True)
widget.use_window_choice.setChecked(True)
widget.window_overlap_slider.setValue(0)
widget.keep_data_on_cpu_box.setChecked(True)

assert widget.check_ready()

widget.model_choice.setCurrentText("WNet")
widget._restrict_window_size_for_model()
assert widget.window_infer_box.isChecked()
assert widget.use_window_choice.isChecked()
assert widget.window_size_choice.currentText() == "64"

test_model_name = "test"
MODEL_LIST[test_model_name] = TestModel
widget.model_choice.addItem(test_model_name)
widget.model_choice.setCurrentText(test_model_name)

widget.window_infer_box.setChecked(False)
widget.use_window_choice.setChecked(False)
widget.worker_config = widget._set_worker_config()
assert widget.worker_config is not None
assert widget.model_info is not None
Expand All @@ -61,7 +61,7 @@ def test_inference(make_napari_viewer_proxy, qtbot):

res = next(worker.inference())
assert isinstance(res, InferenceResult)
assert res.result.shape == (8, 8, 8)
assert res.semantic_segmentation.shape == (8, 8, 8)
assert res.instance_labels.shape == (8, 8, 8)
widget.on_yield(res)

Expand All @@ -73,7 +73,7 @@ def test_inference(make_napari_viewer_proxy, qtbot):
instance_labels=mock_labels,
crf_results=mock_image,
stats=[volume_stats(mock_labels)],
result=mock_image,
semantic_segmentation=mock_image,
model_name="test",
)
num_layers = len(viewer.layers)
Expand Down
4 changes: 4 additions & 0 deletions napari_cellseg3d/_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def test_sphericities():
sphericity_axes = 0
except ValueError:
sphericity_axes = 0
if sphericity_axes is None:
sphericity_axes = (
0 # errors already handled in function, returns None
)
assert 0 <= sphericity_axes <= 1


Expand Down
47 changes: 31 additions & 16 deletions napari_cellseg3d/code_models/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,47 @@

Implemented using the pydense library available at https://github.com/lucasb-eyer/pydensecrf.
"""
import importlib

from warnings import warn
import numpy as np
from napari.qt.threading import GeneratorWorker

from napari_cellseg3d.config import CRFConfig
from napari_cellseg3d.utils import LOGGER as logger

try:
spec = importlib.util.find_spec("pydensecrf")
CRF_INSTALLED = spec is not None
if not CRF_INSTALLED:
logger.info(
"pydensecrf not installed, CRF post-processing will not be available. "
"Please install by running pip install cellseg3d[crf]"
"This is not a hard requirement, you do not need it to install it unless you want to use the CRF post-processing step."
)
else:
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import (
create_pairwise_bilateral,
create_pairwise_gaussian,
unary_from_softmax,
)

CRF_INSTALLED = True
except ImportError:
warn(
"pydensecrf not installed, CRF post-processing will not be available. "
"Please install by running pip install cellseg3d[crf]",
stacklevel=1,
)
CRF_INSTALLED = False

# try:
# import pydensecrf.densecrf as dcrf
# from pydensecrf.utils import (
# create_pairwise_bilateral,
# create_pairwise_gaussian,
# unary_from_softmax,
# )
# CRF_INSTALLED = True
# except (ImportError, ModuleNotFoundError):
# logger.info(
# "pydensecrf not installed, CRF post-processing will not be available. "
# "Please install by running pip install cellseg3d[crf]"
# "This is not a hard requirement, you do not need it to install it unless you want to use the CRF post-processing step."
# )
# CRF_INSTALLED = False
# use importlib instead to check if pydensecrf is installed

import numpy as np
from napari.qt.threading import GeneratorWorker

from napari_cellseg3d.config import CRFConfig
from napari_cellseg3d.utils import LOGGER as logger

__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard"
__credits__ = [
Expand Down
50 changes: 46 additions & 4 deletions napari_cellseg3d/code_models/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def run_method_on_channels_from_params(self, image):
return result.squeeze()

@staticmethod
def sliding_window(volume, func, patch_size=512):
def sliding_window(volume, func, patch_size=512, increment_labels=True):
"""Given a volume of dimensions HxWxD, runs the provided function segmentation on the volume using a sliding window of size patch_size.

If the edge has been reached, the patch size is reduced to fit the remaining space.
Expand All @@ -202,14 +202,17 @@ def sliding_window(volume, func, patch_size=512):
volume (np.array): The volume to segment
func (callable): Function to use for instance segmentation. Should be a partial function with the parameters already set.
patch_size (int): The size of the sliding window.
increment_labels (bool): If True, increments the labels of each patch by the maximum label of the previous patch.

Returns:
np.array: Instance segmentation labels from
"""
result = np.zeros(volume.shape, dtype=np.uint32)
max_label_id = 0
x, y, z = volume.shape
for i in tqdm(range(0, x, patch_size)):
pbar_total = (x // patch_size) * (y // patch_size) * (z // patch_size)
pbar = tqdm(total=pbar_total)
for i in range(0, x, patch_size):
for j in range(0, y, patch_size):
for k in range(0, z, patch_size):
patch = volume[
Expand All @@ -220,13 +223,16 @@ def sliding_window(volume, func, patch_size=512):
patch_result = func(patch)
patch_result = np.array(patch_result)
# make sure labels are unique, only where result is not 0
patch_result[patch_result > 0] += max_label_id
if increment_labels:
patch_result[patch_result > 0] += max_label_id
max_label_id = np.max(patch_result)
result[
i : min(i + patch_size, x),
j : min(j + patch_size, y),
k : min(k + patch_size, z),
] = patch_result
max_label_id = np.max(patch_result)
pbar.update(1)
pbar.close()
return result


Expand Down Expand Up @@ -363,6 +369,42 @@ def binary_watershed(
return np.array(segm)


def clear_large_objects(image, large_label_size=200, use_window=True):
"""Uses watershed to label all obejcts, and removes the ones with a volume larger than the specified threshold.

This is intended for artifact removal, and should not be used for instance segmentation.

Args:
image: array containing the image
large_label_size: size threshold for removal of objects in pixels. E.g. if 10, all objects larger than 10 pixels as a whole will be removed.
use_window: if True, will use a sliding window to perform instance segmentation to avoid memory issues. Default : True

Returns:
array: The image with large objects removed
"""
if use_window:
func = partial(
binary_watershed,
thres_objects=0,
thres_seeding=0,
thres_small=large_label_size,
rem_seed_thres=0,
)
res = InstanceMethod.sliding_window(
image, func, increment_labels=False
)
return np.where(res > 0, 0, image)

labeled = binary_watershed(
image,
thres_objects=0,
thres_seeding=0,
thres_small=large_label_size,
rem_seed_thres=0,
)
return np.where(labeled > 0, 0, image)


def clear_small_objects(image, threshold, is_file_path=False):
"""Calls skimage.remove_small_objects to remove small fragments that might be artifacts.

Expand Down
36 changes: 32 additions & 4 deletions napari_cellseg3d/code_models/worker_inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Contains the :py:class:`~InferenceWorker` class, which is a custom worker to run inference jobs in."""
import platform
import sys
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -27,7 +28,10 @@
# local
from napari_cellseg3d import config, utils
from napari_cellseg3d.code_models.crf import crf_with_config
from napari_cellseg3d.code_models.instance_segmentation import volume_stats
from napari_cellseg3d.code_models.instance_segmentation import (
clear_large_objects,
volume_stats,
)
from napari_cellseg3d.code_models.workers_utils import (
PRETRAINED_WEIGHTS_DIR,
InferenceResult,
Expand All @@ -36,6 +40,7 @@
QuantileNormalization,
RemapTensor,
Threshold,
TqdmToLogSignal,
WeightsDownloader,
)

Expand Down Expand Up @@ -96,6 +101,7 @@ def __init__(
super().__init__(self.inference)
self._signals = LogSignal() # add custom signals
self.log_signal = self._signals.log_signal
self.log_w_replace_signal = self._signals.log_w_replace_signal
self.warn_signal = self._signals.warn_signal
self.error_signal = self._signals.error_signal

Expand Down Expand Up @@ -127,6 +133,14 @@ def log(self, text):
"""
self.log_signal.emit(text)

def log_w_replacement(self, text):
"""Sends a signal that ``text`` should be logged, replacing the last line.

Args:
text (str): text to logged
"""
self.log_w_replace_signal.emit(text)

def warn(self, warning):
"""Sends a warning to main thread."""
self.warn_signal.emit(warning)
Expand Down Expand Up @@ -383,11 +397,14 @@ def model_output_wrapper(inputs):
)
return result
##########################################

return post_process_transforms(result)

model.eval()
with torch.no_grad():
### Redirect tqdm pbar to logger
old_stdout = sys.stderr
sys.stderr = TqdmToLogSignal(self.log_w_replacement)
###
outputs = sliding_window_inference(
inputs,
roi_size=window_size,
Expand All @@ -400,6 +417,8 @@ def model_output_wrapper(inputs):
sigma_scale=0.01,
progress=True,
)
###
sys.stderr = old_stdout
except Exception as e:
logger.exception(e)
logger.debug("failed to run sliding window inference")
Expand Down Expand Up @@ -483,7 +502,7 @@ def create_inference_result(
instance_labels=instance_labels,
crf_results=crf_results,
stats=stats,
result=semantic_labels,
semantic_segmentation=semantic_labels,
model_name=self.config.model_info.name,
)

Expand Down Expand Up @@ -511,8 +530,16 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1):
raise ValueError(
"An ID should be provided when running from a file"
)

# old_stderr = sys.stderr
# sys.stderr = TqdmToLogSignal(self.log_w_replacement)
if self.config.post_process_config.instance.enabled:
if self.config.post_process_config.artifact_removal:
self.log("Removing artifacts...")
semantic_labels = clear_large_objects(
semantic_labels,
self.config.post_process_config.artifact_removal_size,
)

instance_labels = self.instance_seg(
semantic_labels,
i + 1,
Expand All @@ -521,6 +548,7 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1):
else:
instance_labels = None
stats = None
# sys.stderr = old_stderr
return instance_labels, stats

def save_image(
Expand Down
31 changes: 21 additions & 10 deletions napari_cellseg3d/code_models/worker_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,24 @@
)

logger = utils.LOGGER
VERBOSE_SCHEDULER = True
logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}")

try:
import wandb
from wandb import (
init,
)

# used to check if wandb is installed, otherwise not used this way
WANDB_INSTALLED = True
except ImportError:
logger.warning(
except (ImportError, ModuleNotFoundError):
logger.info(
"wandb not installed, wandb config will not be taken into account",
stacklevel=1,
)
WANDB_INSTALLED = False

VERBOSE_SCHEDULER = True
logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}")

"""
Writing something to log messages from outside the main thread needs specific care,
Following the instructions in the guides below to have a worker with custom signals,
Expand Down Expand Up @@ -1109,11 +1113,18 @@ def train(
if WANDB_INSTALLED:
config_dict = self.config.__dict__
logger.debug(f"wandb config : {config_dict}")
wandb.init(
config=config_dict,
project="CellSeg3D",
mode=self.wandb_config.mode,
)
try:
wandb.init(
config=config_dict,
project="CellSeg3D",
mode=self.wandb_config.mode,
)
except AttributeError:
logger.warning(
"Could not initialize wandb."
"This might be due to running napari in a folder where there is a directory named 'wandb'."
"Aborting, please run napari in a different folder or install wandb. Sorry for the inconvenience."
)

if deterministic_config.enabled:
set_determinism(
Expand Down
Loading