diff --git a/README.md b/README.md index 635403a6..bd1e12b3 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,8 @@ and launching the plugin from the Plugins menu. You may use the test volume in the `examples` folder to test the inference and review tools. This should run in far less than five minutes on a modern computer. +You may also find a demo Colab notebook in the `notebooks` folder. + ## Issues **Help us make the code better by reporting issues and adding your feature requests!** diff --git a/examples/c5image.tif b/examples/c5image.tif new file mode 100644 index 00000000..e16245fc Binary files /dev/null and b/examples/c5image.tif differ diff --git a/napari_cellseg3d/dev_scripts/remote_inference.py b/napari_cellseg3d/dev_scripts/remote_inference.py new file mode 100644 index 00000000..7a28bf51 --- /dev/null +++ b/napari_cellseg3d/dev_scripts/remote_inference.py @@ -0,0 +1,188 @@ +"""Script to perform inference on a single image and run post-processing on the results, withot napari.""" +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import List + +import numpy as np +import torch + +from napari_cellseg3d.code_models.instance_segmentation import ( + clear_large_objects, + clear_small_objects, + threshold, + volume_stats, + voronoi_otsu, +) +from napari_cellseg3d.code_models.worker_inference import InferenceWorker +from napari_cellseg3d.config import ( + InferenceWorkerConfig, + InstanceSegConfig, + ModelInfo, + SlidingWindowConfig, +) +from napari_cellseg3d.utils import resize + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +class LogFixture: + """Fixture for napari-less logging, replaces napari_cellseg3d.interface.Log in model_workers. + + This allows to redirect the output of the workers to stdout instead of a specialized widget. + """ + + def __init__(self): + """Creates a LogFixture object.""" + super(LogFixture, self).__init__() + + def print_and_log(self, text, printing=None): + """Prints and logs text.""" + print(text) + + def warn(self, warning): + """Logs warning.""" + logger.warning(warning) + + def error(self, e): + """Logs error.""" + raise (e) + + +WINDOW_SIZE = 64 + +MODEL_INFO = ModelInfo( + name="SwinUNetR", + model_input_size=64, +) + +CONFIG = InferenceWorkerConfig( + device="cuda" if torch.cuda.is_available() else "cpu", + model_info=MODEL_INFO, + results_path=str(Path("./results").absolute()), + compute_stats=False, + sliding_window_config=SlidingWindowConfig(WINDOW_SIZE, 0.25), +) + + +@dataclass +class PostProcessConfig: + """Config for post-processing.""" + + threshold: float = 0.4 + spot_sigma: float = 0.55 + outline_sigma: float = 0.55 + isotropic_spot_sigma: float = 0.2 + isotropic_outline_sigma: float = 0.2 + anisotropy_correction: List[ + float + ] = None # TODO change to actual values, should be a ratio like [1,1/5,1] + clear_small_size: int = 5 + clear_large_objects: int = 500 + + +def inference_on_images( + image: np.array, config: InferenceWorkerConfig = CONFIG +): + """This function provides inference on an image with minimal config. + + Args: + image (np.array): Image to perform inference on. + config (InferenceWorkerConfig, optional): Config for InferenceWorker. Defaults to CONFIG, see above. + """ + # instance_method = InstanceSegmentationWrapper(voronoi_otsu, {"spot_sigma": 0.7, "outline_sigma": 0.7}) + + config.post_process_config.zoom.enabled = False + config.post_process_config.thresholding.enabled = ( + False # will need to be enabled and set to 0.5 for the test images + ) + config.post_process_config.instance = InstanceSegConfig( + enabled=False, + ) + + config.layer = image + + log = LogFixture() + worker = InferenceWorker(config) + logger.debug(f"Worker config: {worker.config}") + + worker.log_signal.connect(log.print_and_log) + worker.warn_signal.connect(log.warn) + worker.error_signal.connect(log.error) + + worker.log_parameters() + + results = [] + # append the InferenceResult when yielded by worker to results + for result in worker.inference(): + results.append(result) + + return results + + +def post_processing(semantic_segmentation, config: PostProcessConfig = None): + """Run post-processing on inference results.""" + config = PostProcessConfig() if config is None else config + # if config.anisotropy_correction is None: + # config.anisotropy_correction = [1, 1, 1 / 5] + if config.anisotropy_correction is None: + config.anisotropy_correction = [1, 1, 1] + + image = semantic_segmentation + # apply threshold to semantic segmentation + logger.info(f"Thresholding with {config.threshold}") + image = threshold(image, config.threshold) + logger.debug(f"Thresholded image shape: {image.shape}") + # remove artifacts by clearing large objects + logger.info(f"Clearing large objects with {config.clear_large_objects}") + image = clear_large_objects(image, config.clear_large_objects) + # run instance segmentation + logger.info( + f"Running instance segmentation with {config.spot_sigma} and {config.outline_sigma}" + ) + labels = voronoi_otsu( + image, + spot_sigma=config.spot_sigma, + outline_sigma=config.outline_sigma, + ) + # clear small objects + logger.info(f"Clearing small objects with {config.clear_small_size}") + labels = clear_small_objects(labels, config.clear_small_size).astype( + np.uint16 + ) + logger.debug(f"Labels shape: {labels.shape}") + # get volume stats WITH ANISOTROPY + logger.debug(f"NUMBER OF OBJECTS: {np.max(np.unique(labels))-1}") + stats_not_resized = volume_stats(labels) + ######## RUN WITH ANISOTROPY ######## + result_dict = {} + result_dict["Not resized"] = { + "labels": labels, + "stats": stats_not_resized, + } + + if config.anisotropy_correction != [1, 1, 1]: + logger.info("Resizing image to correct anisotropy") + image = resize(image, config.anisotropy_correction) + logger.debug(f"Resized image shape: {image.shape}") + logger.info("Running labels without anisotropy") + labels_resized = voronoi_otsu( + image, + spot_sigma=config.isotropic_spot_sigma, + outline_sigma=config.isotropic_outline_sigma, + ) + logger.info( + f"Clearing small objects with {config.clear_large_objects}" + ) + labels_resized = clear_small_objects( + labels_resized, config.clear_small_size + ).astype(np.uint16) + logger.debug( + f"NUMBER OF OBJECTS: {np.max(np.unique(labels_resized))-1}" + ) + logger.info("Getting volume stats without anisotropy") + stats_resized = volume_stats(labels_resized) + return labels_resized, stats_resized + + return labels, stats_not_resized diff --git a/notebooks/colab_inference_demo.ipynb b/notebooks/colab_inference_demo.ipynb new file mode 100644 index 00000000..84748f36 --- /dev/null +++ b/notebooks/colab_inference_demo.ipynb @@ -0,0 +1,320 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# **CellSeg3D : inference demo notebook**\n", + "\n", + "---\n", + "This notebook is part of the [CellSeg3D project](https://github.com/AdaptiveMotorControlLab/CellSeg3d) in the [Mathis Lab of Adaptive Intelligence](https://www.mackenziemathislab.org/).\n", + "\n", + "- 💜 The foundation of this notebook owes much to the **[ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic)** project and to the **[DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)** team for bringing Colab into scientific open software." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# **1. Installing dependencies**\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## **1.1 Installing CellSeg3D**\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@markdown ##Install CellSeg3D and dependencies\n", + "!git clone https://github.com/AdaptiveMotorControlLab/CellSeg3d.git --branch main --single-branch ./CellSeg3D\n", + "!pip install -e CellSeg3D" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## **1.2. Restart your runtime**\n", + "---\n", + "\n", + "\n", + "\n", + "** Please ignore the subsequent error message. An automatic restart of your Runtime is expected and is part of the process.**\n", + "\n", + "\"\"
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Force session restart\n", + "exit(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## **1.3 Load key dependencies**\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Load libraries\n", + "from pathlib import Path\n", + "from tifffile import imread\n", + "from napari_cellseg3d.dev_scripts import remote_inference as cs3d\n", + "from napari_cellseg3d.utils import LOGGER as logger\n", + "import logging\n", + "\n", + "logger.setLevel(logging.INFO)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# **2. Inference**\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## **2.1. Check for GPU access**\n", + "---\n", + "\n", + "By default, this session is configured to use Python 3 and GPU acceleration. To verify or adjust these settings:\n", + "\n", + "Navigate to Runtime and select Change the Runtime type.\n", + "\n", + "For Runtime type, ensure it's set to Python 3 (the programming language this program is written in).\n", + "\n", + "Under Accelerator, choose GPU (Graphics Processing Unit).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@markdown This cell verifies if GPU access is available.\n", + "\n", + "import torch\n", + "if not torch.cuda.is_available():\n", + " print('You do not have GPU access.')\n", + " print('Did you change your runtime?')\n", + " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", + " print('Expect slow performance. To access GPU try reconnecting later')\n", + "\n", + "else:\n", + " print('You have GPU access')\n", + " !nvidia-smi\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## **2.2 Run inference**\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Load demo image and inference configuration\n", + "#@markdown This cell loads a demo image and load the inference configuration.\n", + "demo_image_path = \"./CellSeg3D/examples/c5image.tif\n", + "demo_image = imread(demo_image_path)\n", + "inference_config = cs3d.CONFIG\n", + "post_process_config = cs3d.PostProcessConfig()\n", + "# select cle device for colab\n", + "import pyclesperanto_prototype as cle\n", + "cle.select_device(\"cupy\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Run inference on demo image\n", + "#@markdown This cell runs the inference on the demo image.\n", + "result = cs3d.inference_on_images(\n", + " demo_image,\n", + " config=inference_config,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Post-process the result\n", + "# @markdown This cell post-processes the result of the inference : thresholding, instance segmentation, and statistics.\n", + "instance_segmentation,stats = cs3d.post_processing(\n", + " result[0].semantic_segmentation,\n", + " config=post_process_config,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Display the result\n", + "#@markdown This cell displays the result of the inference and post-processing. Use the slider to navigate through the z-stack.\n", + "# @markdown *KNOWN ISSUE* : The colormap of the labels is not consistent between the z-stacks. \n", + "import matplotlib.pyplot as plt\n", + "import ipywidgets as widgets\n", + "from IPython.display import display\n", + "import matplotlib\n", + "import colorsys\n", + "import numpy as np\n", + "\n", + "def random_label_cmap(n=2**16, h = (0,1), l = (.4,1), s =(.2,.8)):\n", + " \"\"\"FUNCTION TAKEN FROM STARDIST REPO : https://github.com/stardist/stardist/blob/c6c261081c6f9717fa9f5c47720ad2d5a9153224/stardist/plot/plot.py#L8\"\"\"\n", + " h,l,s = np.random.uniform(*h,n), np.random.uniform(*l,n), np.random.uniform(*s,n)\n", + " cols = np.stack([colorsys.hls_to_rgb(_h,_l,_s) for _h,_l,_s in zip(h,l,s)],axis=0)\n", + " cols[0] = 0\n", + " # reset the random generator to the first draw to keep the colormap consistent\n", + "\n", + " return matplotlib.colors.ListedColormap(cols)\n", + "\n", + "label_cmap = random_label_cmap(n=instance_segmentation.max()+1)\n", + "\n", + "def update_plot(z):\n", + " plt.figure(figsize=(15, 15))\n", + " plt.subplot(1, 3, 1)\n", + " plt.imshow(demo_image[z], cmap='gray')\n", + " plt.subplot(1, 3, 2)\n", + " plt.imshow(result[0].semantic_segmentation[z], cmap='turbo')\n", + " plt.subplot(1, 3, 3)\n", + " plt.imshow(instance_segmentation[z], cmap=label_cmap)\n", + " plt.show()\n", + "\n", + "# Create a slider\n", + "z_slider = widgets.IntSlider(min=0, max=demo_image.shape[0]-1, step=1, value=demo_image.shape[0] // 2)\n", + "\n", + "# Display the slider and update the plot when the slider is changed\n", + "widgets.interact(update_plot, z=z_slider)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Display the statistics\n", + "# @markdown This cell displays the statistics of the post-processed result.\n", + "import pandas as pd\n", + "data = pd.DataFrame(stats.get_dict())\n", + "display(data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Plot the a 3D view, with statistics\n", + "# @markdown This cell plots a 3D view of the cells, with the volume as the size of the points and the sphericity as the color.\n", + "import plotly.graph_objects as go\n", + "import numpy as np\n", + "\n", + "def plotly_cells_stats(data):\n", + "\n", + " x = data[\"Centroid x\"]\n", + " y = data[\"Centroid y\"]\n", + " z = data[\"Centroid z\"]\n", + "\n", + " fig = go.Figure(\n", + " data=go.Scatter3d(\n", + " x=np.floor(x),\n", + " y=np.floor(y),\n", + " z=np.floor(z),\n", + " mode=\"markers\",\n", + " marker=dict(\n", + " sizemode=\"diameter\",\n", + " sizeref=30,\n", + " sizemin=20,\n", + " size=data[\"Volume\"],\n", + " color=data[\"Sphericity (axes)\"],\n", + " colorscale=\"Turbo_r\",\n", + " colorbar_title=\"Sphericity\",\n", + " line_color=\"rgb(140, 140, 170)\",\n", + " ),\n", + " )\n", + " )\n", + "\n", + " fig.update_layout(\n", + " height=600,\n", + " width=600,\n", + " title=f'Total number of cells : {int(data[\"Number objects\"][0])}',\n", + " )\n", + "\n", + " fig.show(renderer=\"colab\")\n", + " \n", + "plotly_cells_stats(data)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}