From 1836d6718cad4646f56edcfd964c1ffdd0d7d72a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 28 Nov 2023 12:24:35 +0100 Subject: [PATCH 01/10] WIP add artifact removal --- .../_tests/test_plugin_inference.py | 6 +-- .../code_models/instance_segmentation.py | 20 ++++++++++ .../code_plugins/plugin_model_inference.py | 40 +++++++++++++++---- 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index 9f004b17..e6d4441c 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -25,7 +25,7 @@ def test_inference(make_napari_viewer_proxy, qtbot): assert len(viewer.layers) == 1 - widget.window_infer_box.setChecked(True) + widget.use_window_choice.setChecked(True) widget.window_overlap_slider.setValue(0) widget.keep_data_on_cpu_box.setChecked(True) @@ -33,7 +33,7 @@ def test_inference(make_napari_viewer_proxy, qtbot): widget.model_choice.setCurrentText("WNet") widget._restrict_window_size_for_model() - assert widget.window_infer_box.isChecked() + assert widget.use_window_choice.isChecked() assert widget.window_size_choice.currentText() == "64" test_model_name = "test" @@ -41,7 +41,7 @@ def test_inference(make_napari_viewer_proxy, qtbot): widget.model_choice.addItem(test_model_name) widget.model_choice.setCurrentText(test_model_name) - widget.window_infer_box.setChecked(False) + widget.use_window_choice.setChecked(False) widget.worker_config = widget._set_worker_config() assert widget.worker_config is not None assert widget.model_info is not None diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 044b0e7b..929380b5 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -363,6 +363,26 @@ def binary_watershed( return np.array(segm) +def clear_large_objects(image, large_label_size=200): + """Uses watershed to label all obejcts, and removes the ones with a volume larger than the specified threshold. + + Args: + image: array containing the image + large_label_size: size threshold for removal of objects in pixels. E.g. if 10, all objects larger than 10 pixels as a whole will be removed. + + Returns: + array: The image with large objects removed + """ + labeled = binary_watershed( + image, + thres_objects=0, + thres_seeding=0, + thres_small=large_label_size, + rem_seed_thres=0, + ) + return np.where(labeled > 0, 0, image) + + def clear_small_objects(image, threshold, is_file_path=False): """Calls skimage.remove_small_objects to remove small fragments that might be artifacts. diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index faf8fc49..f1a19dfa 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -150,8 +150,10 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): parent=self, ) - self.window_infer_box = ui.CheckBox("Use window inference") - self.window_infer_box.toggled.connect(self._toggle_display_window_size) + self.use_window_choice = ui.CheckBox("Use window inference") + self.use_window_choice.toggled.connect( + self._toggle_display_window_size + ) sizes_window = ["8", "16", "32", "64", "128", "256", "512"] self._default_window_size = sizes_window.index("64") @@ -187,6 +189,21 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.window_overlap_slider.container, ], ) + ################## + ################## + # auto-artifact removal widgets + self.attempt_artifact_removal_box = ui.CheckBox( + "Attempt artifact removal", + func=self._toggle_display_artifact_size_thresh, + parent=self, + ) + self.artifact_removal_size = ui.IntIncrementCounter( + lower=1, + upper=10000, + default=500, + text_label="Remove larger than :", + step=100, + ) ################## ################## @@ -261,7 +278,7 @@ def _set_tooltips(self): self.thresholding_checkbox.setToolTip(thresh_desc) self.thresholding_slider.tooltips = thresh_desc - self.window_infer_box.setToolTip( + self.use_window_choice.setToolTip( "Sliding window inference runs the model on parts of the image" "\nrather than the whole image, to reduce memory requirements." "\nUse this if you have large images." @@ -304,11 +321,11 @@ def _restrict_window_size_for_model(self): if self.model_choice.currentText() == "WNet": self.wnet_enabled = True self.window_size_choice.setCurrentIndex(self._default_window_size) - self.window_infer_box.setChecked(self.wnet_enabled) + self.use_window_choice.setChecked(self.wnet_enabled) self.window_size_choice.setDisabled( self.wnet_enabled and not self.custom_weights_choice.isChecked() ) - self.window_infer_box.setDisabled( + self.use_window_choice.setDisabled( self.wnet_enabled and not self.custom_weights_choice.isChecked() ) @@ -333,6 +350,13 @@ def _toggle_display_thresh(self): self.thresholding_checkbox, self.thresholding_slider.container ) + def _toggle_display_artifact_size_thresh(self): + """Shows the choices for thresholding results depending on whether :py:attr:`self.attempt_artifact_removal_box` is checked.""" + ui.toggle_visibility( + self.attempt_artifact_removal_box, + self.artifact_removal_size.container, + ) + def _toggle_display_crf(self): """Shows the choices for CRF post-processing depending on whether :py:attr:`self.use_crf` is checked.""" ui.toggle_visibility(self.use_crf, self.crf_widgets) @@ -343,7 +367,7 @@ def _toggle_display_instance(self): def _toggle_display_window_size(self): """Show or hide window size choice depending on status of self.window_infer_box.""" - ui.toggle_visibility(self.window_infer_box, self.window_infer_params) + ui.toggle_visibility(self.use_window_choice, self.window_infer_params) def _load_weights_path(self): """Show file dialog to set :py:attr:`model_path`.""" @@ -433,7 +457,7 @@ def _build(self): ui.add_widgets( inference_param_group_l, [ - self.window_infer_box, + self.use_window_choice, self.window_infer_params, self.keep_data_on_cpu_box, self.device_choice.label, @@ -811,7 +835,7 @@ def _set_worker_config(self) -> config.InferenceWorkerConfig: instance=self.instance_config, ) - if self.window_infer_box.isChecked(): + if self.use_window_choice.isChecked(): size = int(self.window_size_choice.currentText()) window_config = config.SlidingWindowConfig( window_size=size, From 5298e97dbbd903c07d70b19fbf372c7948b1e56d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Dec 2023 15:12:14 +0100 Subject: [PATCH 02/10] Add utils and inference worker artifact removal --- .../code_models/instance_segmentation.py | 2 + .../code_models/worker_inference.py | 12 +- .../code_plugins/plugin_convert.py | 109 ++++++++++++++++++ .../code_plugins/plugin_model_inference.py | 39 ++++++- .../code_plugins/plugin_utilities.py | 5 + napari_cellseg3d/config.py | 2 + napari_cellseg3d/interface.py | 2 +- 7 files changed, 167 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index 929380b5..abcae370 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -366,6 +366,8 @@ def binary_watershed( def clear_large_objects(image, large_label_size=200): """Uses watershed to label all obejcts, and removes the ones with a volume larger than the specified threshold. + This is intended for artifact removal, and should not be used for instance segmentation. + Args: image: array containing the image large_label_size: size threshold for removal of objects in pixels. E.g. if 10, all objects larger than 10 pixels as a whole will be removed. diff --git a/napari_cellseg3d/code_models/worker_inference.py b/napari_cellseg3d/code_models/worker_inference.py index 7888ebeb..86669495 100644 --- a/napari_cellseg3d/code_models/worker_inference.py +++ b/napari_cellseg3d/code_models/worker_inference.py @@ -27,7 +27,10 @@ # local from napari_cellseg3d import config, utils from napari_cellseg3d.code_models.crf import crf_with_config -from napari_cellseg3d.code_models.instance_segmentation import volume_stats +from napari_cellseg3d.code_models.instance_segmentation import ( + clear_large_objects, + volume_stats, +) from napari_cellseg3d.code_models.workers_utils import ( PRETRAINED_WEIGHTS_DIR, InferenceResult, @@ -513,6 +516,13 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): ) if self.config.post_process_config.instance.enabled: + if self.config.post_process_config.artifact_removal: + self.log("Removing artifacts...") + semantic_labels = clear_large_objects( + semantic_labels, + self.config.post_process_config.artifact_removal_size, + ) + instance_labels = self.instance_seg( semantic_labels, i + 1, diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index ea90d707..439623de 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -12,6 +12,7 @@ from napari_cellseg3d import utils from napari_cellseg3d.code_models.instance_segmentation import ( InstanceWidgets, + clear_large_objects, clear_small_objects, threshold, to_semantic, @@ -26,6 +27,114 @@ logger = utils.LOGGER +class ArtifactRemovalUtils(BasePluginUtils): + """Class to remove artifacts from images by removing large objects.""" + + save_path = Path.home() / "cellseg3d" / "artifact_removed" + + def __init__(self, viewer: "napari.Viewer.viewer", parent=None): + """Creates a ArtifactRemovalUtils widget. + + Args: + viewer: viewer in which to process data + parent: parent widget + """ + super().__init__( + viewer, + parent, + loads_labels=True, + loads_images=False, + ) + self.data_panel = self._build_io_panel() + self.container = None + self.start_btn = ui.Button("Start", self._start) + + self.artifact_size_counter = ui.IntIncrementCounter( + lower=0, + upper=100000, + default=100, + text_label="Remove all larger than\n(volume in pxs):", + ) + + self.label_layer_loader.layer_list.label.setText("Layer :") + self.label_layer_loader.set_layer_type(napari.layers.Labels) + + self.results_path = str(self.save_path) + self.results_filewidget.text_field.setText(str(self.results_path)) + self.results_filewidget.check_ready() + + self._build() + + def _build(self): + container = ui.ContainerWidget() + self.container = container + + ui.add_widgets( + self.data_panel.layout, + [ + self.data_panel, + ui.add_blank(container), + self.artifact_size_counter.label, + self.artifact_size_counter, + # ui.add_blank(container), + self.start_btn, + ], + ) + container.layout.addWidget(self.data_panel) + ui.ScrollArea.make_scrollable( + container.layout, + self, + max_wh=[MAX_W, MAX_H], + ) + self._set_io_visibility() + container.setSizePolicy( + QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding + ) + return container + + def _remove_large(self, data, size): + return clear_large_objects(data, size).astype(np.uint16) + + def _start(self): + utils.mkdir_from_str(self.results_path) + remove_size = self.artifact_size_counter.value() + + if self.layer_choice.isChecked(): + if self.label_layer_loader.layer_data() is not None: + layer = self.label_layer_loader.layer() + + data = np.array(layer.data) + removed = self._remove_large(data, remove_size) + + utils.save_layer( + self.results_path, + f"artifact_removed_{layer.name}_{utils.get_date_time()}.tif", + removed, + ) + self.layer = utils.show_result( + self._viewer, + layer, + removed, + f"artifact_removed_{layer.name}", + existing_layer=self.layer, + ) + elif ( + self.folder_choice.isChecked() and len(self.labels_filepaths) != 0 + ): + images = [ + self._remove_large(imread(file), remove_size) + for file in self.labels_filepaths + ] + utils.save_folder( + self.results_path, + f"artifact_removed_results_{utils.get_date_time()}", + images, + self.labels_filepaths, + ) + else: + logger.warning("Please specify a layer or a folder") + + class FragmentUtils(BasePluginUtils): """Class to crop large 3D volumes into smaller fragments.""" diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index f1a19dfa..d2759bdb 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -197,6 +197,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): func=self._toggle_display_artifact_size_thresh, parent=self, ) + self.remove_artifacts_label = ui.make_label( + "Remove labels larger than :" + ) self.artifact_removal_size = ui.IntIncrementCounter( lower=1, upper=10000, @@ -204,7 +207,10 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): text_label="Remove larger than :", step=100, ) - + self.artifact_container = ui.ContainerWidget(parent=self) + self.attempt_artifact_removal_box.toggled.connect( + self._toggle_display_artifact_size_thresh + ) ################## ################## # instance segmentation widgets @@ -216,6 +222,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): func=self._toggle_display_instance, parent=self, ) + self.use_instance_choice.toggled.connect( + self._toggle_artifact_removal_widgets + ) self.use_crf = ui.CheckBox( "Use CRF post-processing", func=self._toggle_display_crf, @@ -300,6 +309,12 @@ def _set_tooltips(self): "Will save several statistics for each object to a csv in the results folder. Stats include : " "volume, centroid coordinates, sphericity" ) + artifact_tooltip = "If enabled, will remove labels of objects larger than the chosen size in instance segmentation" + self.artifact_removal_size.setToolTip( + artifact_tooltip + "\nDefault is 500 pixels" + ) + self.attempt_artifact_removal_box.setToolTip(artifact_tooltip) + self.artifact_container.setToolTip(artifact_tooltip) ################## ################## @@ -354,7 +369,11 @@ def _toggle_display_artifact_size_thresh(self): """Shows the choices for thresholding results depending on whether :py:attr:`self.attempt_artifact_removal_box` is checked.""" ui.toggle_visibility( self.attempt_artifact_removal_box, - self.artifact_removal_size.container, + self.artifact_removal_size, + ) + ui.toggle_visibility( + self.attempt_artifact_removal_box, + self.remove_artifacts_label, ) def _toggle_display_crf(self): @@ -365,6 +384,10 @@ def _toggle_display_instance(self): """Shows or hides the options for instance segmentation based on current user selection.""" ui.toggle_visibility(self.use_instance_choice, self.instance_widgets) + def _toggle_artifact_removal_widgets(self): + """Shows or hides the options for instance segmentation based on current user selection.""" + ui.toggle_visibility(self.use_instance_choice, self.artifact_container) + def _toggle_display_window_size(self): """Show or hide window size choice depending on status of self.window_infer_box.""" ui.toggle_visibility(self.use_window_choice, self.window_infer_params) @@ -482,6 +505,15 @@ def _build(self): self.thresholding_slider.container.setVisible(False) + ui.add_widgets( + self.artifact_container.layout, + [ + self.attempt_artifact_removal_box, + self.remove_artifacts_label, + self.artifact_removal_size, + ], + ) + self.artifact_container.setVisible(False) ui.add_widgets( post_proc_layout, [ @@ -492,6 +524,7 @@ def _build(self): self.crf_widgets, self.use_instance_choice, self.instance_widgets, + self.artifact_container, self.save_stats_to_csv_box, # self.instance_param_container, # instance segmentation ], @@ -833,6 +866,8 @@ def _set_worker_config(self) -> config.InferenceWorkerConfig: zoom=zoom_config, thresholding=thresholding_config, instance=self.instance_config, + artifact_removal=self.attempt_artifact_removal_box.isChecked(), + artifact_removal_size=self.artifact_removal_size.value(), ) if self.use_window_choice.isChecked(): diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index c90734ca..538a5410 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -14,6 +14,7 @@ from napari_cellseg3d.code_plugins.plugin_base import BasePluginUtils from napari_cellseg3d.code_plugins.plugin_convert import ( AnisoUtils, + ArtifactRemovalUtils, FragmentUtils, RemoveSmallUtils, StatsUtils, @@ -25,6 +26,8 @@ from napari_cellseg3d.code_plugins.plugin_crop import Cropping from napari_cellseg3d.utils import LOGGER as logger +# NOTE : to add a new utility: add it to the dictionary below, in attr_names in the Utilities class, and import it above + UTILITIES_WIDGETS = { "Crop": Cropping, "Fragment 3D volume": FragmentUtils, @@ -35,6 +38,7 @@ "Threshold": ThresholdUtils, "CRF": CRFWidget, "Label statistics": StatsUtils, + "Clear large labels": ArtifactRemovalUtils, } @@ -57,6 +61,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): "thresh", "crf", "stats", + "artifacts", ] self._create_utils_widgets(attr_names) self.utils_choice = ui.DropdownMenu( diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index d3aa0eb1..1b2af90f 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -185,6 +185,8 @@ class PostProcessConfig: zoom: Zoom = Zoom() thresholding: Thresholding = Thresholding() instance: InstanceSegConfig = InstanceSegConfig() + artifact_removal: bool = False + artifact_removal_size: int = 500 @dataclass diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 05baf64c..a1fe754a 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1235,7 +1235,7 @@ def __init__( set_spinbox(self, lower, upper, default, step, fixed) self.label = None - self.container = None + # self.container = None if text_label is not None: self.label = make_label(name=text_label) From 33ff09de8becf145866afdeab8b2f9ce5b0120b1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 4 Dec 2023 15:58:33 +0100 Subject: [PATCH 03/10] Fix issue with layer type in utils --- .../code_plugins/plugin_convert.py | 3 +++ napari_cellseg3d/utils.py | 23 +++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index 439623de..b723e61c 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -117,6 +117,7 @@ def _start(self): removed, f"artifact_removed_{layer.name}", existing_layer=self.layer, + add_as_labels=True, ) elif ( self.folder_choice.isChecked() and len(self.labels_filepaths) != 0 @@ -407,6 +408,7 @@ def _start(self): removed, f"cleared_{layer.name}", existing_layer=self.layer, + add_as_labels=True, ) elif ( self.folder_choice.isChecked() and len(self.images_filepaths) != 0 @@ -588,6 +590,7 @@ def _start(self): instance, f"instance_{layer.name}", existing_layer=self.layer, + add_as_labels=True, ) elif ( diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 91bd73ae..81b8e766 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -75,6 +75,8 @@ def show_result( name, existing_layer: napari.layers.Layer = None, colormap="bop orange", + add_as_labels=False, + add_as_image=False, ) -> napari.layers.Layer: """Adds layers to a viewer to show result to user. @@ -85,24 +87,35 @@ def show_result( name: name of the added layer existing_layer: existing layer to update, if any colormap: colormap to use for the layer + add_as_labels: whether to add the layer as a Labels layer. Overrides guessing from layer type. + add_as_image: whether to add the layer as an Image layer. Overrides guessing from layer type. Returns: napari.layers.Layer: the layer added to the viewer """ colormap = colormap if colormap is not None else "gray" if existing_layer is None: - if isinstance(layer, napari.layers.Image): + if add_as_image: LOGGER.info("Added resulting image layer") results_layer = viewer.add_image( image, name=name, colormap=colormap ) - elif isinstance(layer, napari.layers.Labels): + elif add_as_labels: LOGGER.info("Added resulting label layer") results_layer = viewer.add_labels(image, name=name) else: - LOGGER.warning( - f"Results not shown, unsupported layer type {type(layer)}" - ) + if isinstance(layer, napari.layers.Image): + LOGGER.info("Added resulting image layer") + results_layer = viewer.add_image( + image, name=name, colormap=colormap + ) + elif isinstance(layer, napari.layers.Labels): + LOGGER.info("Added resulting label layer") + results_layer = viewer.add_labels(image, name=name) + else: + LOGGER.warning( + f"Results not shown, unsupported layer type {type(layer)}" + ) else: try: viewer.layers[existing_layer.name].data = image From 95856c7c332d21a5cffee798c720496e2654a5c1 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 5 Dec 2023 11:24:11 +0100 Subject: [PATCH 04/10] Add better pbar and windowed artifact rem --- .../code_models/instance_segmentation.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py index abcae370..05fd4667 100644 --- a/napari_cellseg3d/code_models/instance_segmentation.py +++ b/napari_cellseg3d/code_models/instance_segmentation.py @@ -192,7 +192,7 @@ def run_method_on_channels_from_params(self, image): return result.squeeze() @staticmethod - def sliding_window(volume, func, patch_size=512): + def sliding_window(volume, func, patch_size=512, increment_labels=True): """Given a volume of dimensions HxWxD, runs the provided function segmentation on the volume using a sliding window of size patch_size. If the edge has been reached, the patch size is reduced to fit the remaining space. @@ -202,6 +202,7 @@ def sliding_window(volume, func, patch_size=512): volume (np.array): The volume to segment func (callable): Function to use for instance segmentation. Should be a partial function with the parameters already set. patch_size (int): The size of the sliding window. + increment_labels (bool): If True, increments the labels of each patch by the maximum label of the previous patch. Returns: np.array: Instance segmentation labels from @@ -209,7 +210,9 @@ def sliding_window(volume, func, patch_size=512): result = np.zeros(volume.shape, dtype=np.uint32) max_label_id = 0 x, y, z = volume.shape - for i in tqdm(range(0, x, patch_size)): + pbar_total = (x // patch_size) * (y // patch_size) * (z // patch_size) + pbar = tqdm(total=pbar_total) + for i in range(0, x, patch_size): for j in range(0, y, patch_size): for k in range(0, z, patch_size): patch = volume[ @@ -220,13 +223,16 @@ def sliding_window(volume, func, patch_size=512): patch_result = func(patch) patch_result = np.array(patch_result) # make sure labels are unique, only where result is not 0 - patch_result[patch_result > 0] += max_label_id + if increment_labels: + patch_result[patch_result > 0] += max_label_id + max_label_id = np.max(patch_result) result[ i : min(i + patch_size, x), j : min(j + patch_size, y), k : min(k + patch_size, z), ] = patch_result - max_label_id = np.max(patch_result) + pbar.update(1) + pbar.close() return result @@ -363,7 +369,7 @@ def binary_watershed( return np.array(segm) -def clear_large_objects(image, large_label_size=200): +def clear_large_objects(image, large_label_size=200, use_window=True): """Uses watershed to label all obejcts, and removes the ones with a volume larger than the specified threshold. This is intended for artifact removal, and should not be used for instance segmentation. @@ -371,10 +377,24 @@ def clear_large_objects(image, large_label_size=200): Args: image: array containing the image large_label_size: size threshold for removal of objects in pixels. E.g. if 10, all objects larger than 10 pixels as a whole will be removed. + use_window: if True, will use a sliding window to perform instance segmentation to avoid memory issues. Default : True Returns: array: The image with large objects removed """ + if use_window: + func = partial( + binary_watershed, + thres_objects=0, + thres_seeding=0, + thres_small=large_label_size, + rem_seed_thres=0, + ) + res = InstanceMethod.sliding_window( + image, func, increment_labels=False + ) + return np.where(res > 0, 0, image) + labeled = binary_watershed( image, thres_objects=0, From 3687ee3c1b83f2313e413e723dd708784087318d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 5 Dec 2023 11:24:31 +0100 Subject: [PATCH 05/10] Refactor results name to be more explicit --- napari_cellseg3d/_tests/test_inference.py | 2 +- napari_cellseg3d/_tests/test_plugin_inference.py | 4 ++-- napari_cellseg3d/code_models/worker_inference.py | 2 +- napari_cellseg3d/code_models/workers_utils.py | 2 +- .../code_plugins/plugin_model_inference.py | 13 ++++++++----- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/napari_cellseg3d/_tests/test_inference.py b/napari_cellseg3d/_tests/test_inference.py index a0f28b83..62972ba1 100644 --- a/napari_cellseg3d/_tests/test_inference.py +++ b/napari_cellseg3d/_tests/test_inference.py @@ -90,7 +90,7 @@ def __call__(self, x): post_process_transforms=mock_work(), ) assert isinstance(res, InferenceResult) - assert res.result is not None + assert res.semantic_segmentation is not None def test_post_processing(): diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index e6d4441c..de518b3d 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -61,7 +61,7 @@ def test_inference(make_napari_viewer_proxy, qtbot): res = next(worker.inference()) assert isinstance(res, InferenceResult) - assert res.result.shape == (8, 8, 8) + assert res.semantic_segmentation.shape == (8, 8, 8) assert res.instance_labels.shape == (8, 8, 8) widget.on_yield(res) @@ -73,7 +73,7 @@ def test_inference(make_napari_viewer_proxy, qtbot): instance_labels=mock_labels, crf_results=mock_image, stats=[volume_stats(mock_labels)], - result=mock_image, + semantic_segmentation=mock_image, model_name="test", ) num_layers = len(viewer.layers) diff --git a/napari_cellseg3d/code_models/worker_inference.py b/napari_cellseg3d/code_models/worker_inference.py index 86669495..8cd74d9b 100644 --- a/napari_cellseg3d/code_models/worker_inference.py +++ b/napari_cellseg3d/code_models/worker_inference.py @@ -486,7 +486,7 @@ def create_inference_result( instance_labels=instance_labels, crf_results=crf_results, stats=stats, - result=semantic_labels, + semantic_segmentation=semantic_labels, model_name=self.config.model_info.name, ) diff --git a/napari_cellseg3d/code_models/workers_utils.py b/napari_cellseg3d/code_models/workers_utils.py index c8ba6655..05b3e7df 100644 --- a/napari_cellseg3d/code_models/workers_utils.py +++ b/napari_cellseg3d/code_models/workers_utils.py @@ -283,7 +283,7 @@ class InferenceResult: instance_labels: np.array = None crf_results: np.array = None stats: "np.array[ImageStats]" = None - result: np.array = None + semantic_segmentation: np.array = None model_name: str = None diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index d2759bdb..32a63bfd 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -513,7 +513,10 @@ def _build(self): self.artifact_removal_size, ], ) - self.artifact_container.setVisible(False) + # self.attempt_artifact_removal_box.setVisible(False) + self.remove_artifacts_label.setVisible(False) + self.artifact_removal_size.setVisible(False) + ui.add_widgets( post_proc_layout, [ @@ -630,21 +633,21 @@ def _display_results(self, result: InferenceResult): # out_colormap = "twilight" viewer.add_image( - result.result, + result.semantic_segmentation, colormap=out_colormap, name=f"pred_{image_id}_{model_name}", opacity=0.8, ) if ( - len(result.result.shape) == 4 + len(result.semantic_segmentation.shape) == 4 ): # seek channel that is most likely to be foreground fractions_per_channel = utils.channels_fraction_above_threshold( - result.result, 0.5 + result.semantic_segmentation, 0.5 ) index_channel_sorted = np.argsort(fractions_per_channel) for channel in index_channel_sorted: - if result.result[channel].sum() > 0: + if result.semantic_segmentation[channel].sum() > 0: index_channel_least_labelled = channel break viewer.dims.set_point( From a30e28c8dbe8b793a37c18a93679ecd72b4784c3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 5 Dec 2023 11:51:03 +0100 Subject: [PATCH 06/10] Add working pbar to Inference log --- .../code_models/worker_inference.py | 18 ++++++++++++++- napari_cellseg3d/code_models/workers_utils.py | 22 +++++++++++++++++++ .../code_plugins/plugin_model_inference.py | 1 + napari_cellseg3d/interface.py | 6 +++-- 4 files changed, 44 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_models/worker_inference.py b/napari_cellseg3d/code_models/worker_inference.py index 8cd74d9b..c63e7b55 100644 --- a/napari_cellseg3d/code_models/worker_inference.py +++ b/napari_cellseg3d/code_models/worker_inference.py @@ -1,5 +1,6 @@ """Contains the :py:class:`~InferenceWorker` class, which is a custom worker to run inference jobs in.""" import platform +import sys from pathlib import Path import numpy as np @@ -39,6 +40,7 @@ QuantileNormalization, RemapTensor, Threshold, + TqdmToLogSignal, WeightsDownloader, ) @@ -99,6 +101,7 @@ def __init__( super().__init__(self.inference) self._signals = LogSignal() # add custom signals self.log_signal = self._signals.log_signal + self.log_w_replace_signal = self._signals.log_w_replace_signal self.warn_signal = self._signals.warn_signal self.error_signal = self._signals.error_signal @@ -130,6 +133,14 @@ def log(self, text): """ self.log_signal.emit(text) + def log_w_replacement(self, text): + """Sends a signal that ``text`` should be logged, replacing the last line. + + Args: + text (str): text to logged + """ + self.log_w_replace_signal.emit(text) + def warn(self, warning): """Sends a warning to main thread.""" self.warn_signal.emit(warning) @@ -386,11 +397,14 @@ def model_output_wrapper(inputs): ) return result ########################################## - return post_process_transforms(result) model.eval() with torch.no_grad(): + ### Redirect tqdm pbar to logger + old_stdout = sys.stderr + sys.stderr = TqdmToLogSignal(self.log_w_replacement) + ### outputs = sliding_window_inference( inputs, roi_size=window_size, @@ -403,6 +417,8 @@ def model_output_wrapper(inputs): sigma_scale=0.01, progress=True, ) + ### + sys.stderr = old_stdout except Exception as e: logger.exception(e) logger.debug("failed to run sliding window inference") diff --git a/napari_cellseg3d/code_models/workers_utils.py b/napari_cellseg3d/code_models/workers_utils.py index 05b3e7df..a68b0976 100644 --- a/napari_cellseg3d/code_models/workers_utils.py +++ b/napari_cellseg3d/code_models/workers_utils.py @@ -129,6 +129,8 @@ class LogSignal(WorkerBaseSignals): log_signal = Signal(str) """qtpy.QtCore.Signal: signal to be sent when some text should be logged""" + log_w_replace_signal = Signal(str) + """qtpy.QtCore.Signal: signal to be sent when some text should be logged, replacing the last line""" warn_signal = Signal(str) """qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread""" error_signal = Signal(Exception, str) @@ -142,6 +144,26 @@ def __init__(self, parent=None): super().__init__(parent=parent) +class TqdmToLogSignal: + """File-like object to redirect tqdm output to the logger widget in the GUI that self.log emits to.""" + + def __init__(self, log_func): + """Creates a TqdmToLogSignal. + + Args: + log_func (callable): function to call to log the output. + """ + self.log_func = log_func + + def write(self, x): + """Writes the output to the log_func.""" + self.log_func(x.strip()) + + def flush(self): + """Flushes the output. Unused.""" + pass + + class ONNXModelWrapper(torch.nn.Module): """Class to replace torch model by ONNX Runtime session.""" diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 32a63bfd..69e42239 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -751,6 +751,7 @@ def _setup_worker(self): self.worker.started.connect(self.on_start) self.worker.log_signal.connect(self.log.print_and_log) + self.worker.log_w_replace_signal.connect(self.log.replace_last_line) self.worker.warn_signal.connect(self.log.warn) self.worker.error_signal.connect(self.log.error) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index a1fe754a..0b4da851 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -216,8 +216,10 @@ def __init__(self, parent=None): self.lock = threading.Lock() - # def receive_log(self, text): - # self.print_and_log(text) + def flush(self): + """Flush the log.""" + pass + def write(self, message): """Write message to log in a thread-safe manner. From 4d0bdbc2d1d0fe43149ce8a4895d860cf4054b11 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 5 Dec 2023 14:32:14 +0100 Subject: [PATCH 07/10] Reduce requirements --- pyproject.toml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 65963a88..f88c5da2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,24 +29,22 @@ dependencies = [ "numpy", "napari[all]>=0.4.14", "QtPy", - "opencv-python>=4.5.5", +# "opencv-python>=4.5.5", # "dask-image>=0.6.0", "scikit-image>=0.19.2", "matplotlib>=3.4.1", "tifffile>=2022.2.9", - "imageio-ffmpeg>=0.4.5", +# "imageio-ffmpeg>=0.4.5", "imagecodecs>=2023.3.16", "torch>=1.11", "monai[nibabel,einops]>=0.9.0", "itk", "tqdm", - "nibabel", - "scikit-image", - "pillow", +# "nibabel", +# "pillow", "pyclesperanto-prototype", "tqdm", "matplotlib", - "vispy>=0.9.6", ] dynamic = ["version"] From 053600a32edb1e9e469c0e5fc0ed4ac92dd74703 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 5 Dec 2023 14:34:27 +0100 Subject: [PATCH 08/10] Fix spheric. test and import checking --- napari_cellseg3d/_tests/test_utils.py | 4 ++ napari_cellseg3d/code_models/crf.py | 47 ++++++++++++------- .../code_models/worker_inference.py | 4 +- .../code_models/worker_training.py | 23 ++++++--- pyproject.toml | 3 +- 5 files changed, 57 insertions(+), 24 deletions(-) diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index 00d3737c..a9596432 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -64,6 +64,10 @@ def test_sphericities(): sphericity_axes = 0 except ValueError: sphericity_axes = 0 + if sphericity_axes is None: + sphericity_axes = ( + 0 # errors already handled in function, returns None + ) assert 0 <= sphericity_axes <= 1 diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py index 9dc148b2..b1dcc940 100644 --- a/napari_cellseg3d/code_models/crf.py +++ b/napari_cellseg3d/code_models/crf.py @@ -18,10 +18,23 @@ Implemented using the pydense library available at https://github.com/lucasb-eyer/pydensecrf. """ +import importlib -from warnings import warn +import numpy as np +from napari.qt.threading import GeneratorWorker + +from napari_cellseg3d.config import CRFConfig +from napari_cellseg3d.utils import LOGGER as logger -try: +spec = importlib.util.find_spec("pydensecrf") +CRF_INSTALLED = spec is not None +if not CRF_INSTALLED: + logger.info( + "pydensecrf not installed, CRF post-processing will not be available. " + "Please install by running pip install cellseg3d[crf]" + "This is not a hard requirement, you do not need it to install it unless you want to use the CRF post-processing step." + ) +else: import pydensecrf.densecrf as dcrf from pydensecrf.utils import ( create_pairwise_bilateral, @@ -29,21 +42,23 @@ unary_from_softmax, ) - CRF_INSTALLED = True -except ImportError: - warn( - "pydensecrf not installed, CRF post-processing will not be available. " - "Please install by running pip install cellseg3d[crf]", - stacklevel=1, - ) - CRF_INSTALLED = False - +# try: +# import pydensecrf.densecrf as dcrf +# from pydensecrf.utils import ( +# create_pairwise_bilateral, +# create_pairwise_gaussian, +# unary_from_softmax, +# ) +# CRF_INSTALLED = True +# except (ImportError, ModuleNotFoundError): +# logger.info( +# "pydensecrf not installed, CRF post-processing will not be available. " +# "Please install by running pip install cellseg3d[crf]" +# "This is not a hard requirement, you do not need it to install it unless you want to use the CRF post-processing step." +# ) +# CRF_INSTALLED = False +# use importlib instead to check if pydensecrf is installed -import numpy as np -from napari.qt.threading import GeneratorWorker - -from napari_cellseg3d.config import CRFConfig -from napari_cellseg3d.utils import LOGGER as logger __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" __credits__ = [ diff --git a/napari_cellseg3d/code_models/worker_inference.py b/napari_cellseg3d/code_models/worker_inference.py index c63e7b55..7936363b 100644 --- a/napari_cellseg3d/code_models/worker_inference.py +++ b/napari_cellseg3d/code_models/worker_inference.py @@ -530,7 +530,8 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): raise ValueError( "An ID should be provided when running from a file" ) - + # old_stderr = sys.stderr + # sys.stderr = TqdmToLogSignal(self.log_w_replacement) if self.config.post_process_config.instance.enabled: if self.config.post_process_config.artifact_removal: self.log("Removing artifacts...") @@ -547,6 +548,7 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): else: instance_labels = None stats = None + # sys.stderr = old_stderr return instance_labels, stats def save_image( diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index 9ba8be40..1f74e9a1 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -67,15 +67,19 @@ try: import wandb + from wandb import ( + init, + ) WANDB_INSTALLED = True -except ImportError: +except (ImportError, ModuleNotFoundError): logger.warning( "wandb not installed, wandb config will not be taken into account", stacklevel=1, ) WANDB_INSTALLED = False + """ Writing something to log messages from outside the main thread needs specific care, Following the instructions in the guides below to have a worker with custom signals, @@ -1109,11 +1113,18 @@ def train( if WANDB_INSTALLED: config_dict = self.config.__dict__ logger.debug(f"wandb config : {config_dict}") - wandb.init( - config=config_dict, - project="CellSeg3D", - mode=self.wandb_config.mode, - ) + try: + wandb.init( + config=config_dict, + project="CellSeg3D", + mode=self.wandb_config.mode, + ) + except AttributeError: + logger.warning( + "Could not initialize wandb." + "This might be due to running napari in a folder where there is a directory named 'wandb'." + "Aborting, please run napari in a different folder or install wandb. Sorry for the inconvenience." + ) if deterministic_config.enabled: set_determinism( diff --git a/pyproject.toml b/pyproject.toml index f88c5da2..04e25721 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,8 @@ select = [ # Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) # and 'G004' (do not use f-strings in logging) # and 'A003' (Shadowing python builtins) -ignore = ["E501", "E741", "G004", "A003"] +# and 'F401' (imported but unused) +ignore = ["E501", "E741", "G004", "A003", "F401"] exclude = [ ".bzr", ".direnv", From 1bb212833a05193f364f57a381a685eb668de193 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 5 Dec 2023 14:35:34 +0100 Subject: [PATCH 09/10] Update worker_training.py --- napari_cellseg3d/code_models/worker_training.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index 1f74e9a1..c67707e3 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -62,9 +62,6 @@ ) logger = utils.LOGGER -VERBOSE_SCHEDULER = True -logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}") - try: import wandb from wandb import ( @@ -73,12 +70,14 @@ WANDB_INSTALLED = True except (ImportError, ModuleNotFoundError): - logger.warning( + logger.info( "wandb not installed, wandb config will not be taken into account", stacklevel=1, ) WANDB_INSTALLED = False +VERBOSE_SCHEDULER = True +logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}") """ Writing something to log messages from outside the main thread needs specific care, From 8d12de8731f5e0e494640b836fa3ce25fdec8521 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 5 Dec 2023 14:36:37 +0100 Subject: [PATCH 10/10] Update worker_training.py --- napari_cellseg3d/code_models/worker_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index c67707e3..c363fd31 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -68,6 +68,7 @@ init, ) + # used to check if wandb is installed, otherwise not used this way WANDB_INSTALLED = True except (ImportError, ModuleNotFoundError): logger.info(