From 46d20df3d9f4546d1baec525d69ae6570b8d1c2f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Tue, 23 Apr 2024 10:57:47 +0200 Subject: [PATCH 1/8] Add optional train/val split for supervised models --- .../code_models/worker_training.py | 20 +++++++++++++++++-- napari_cellseg3d/config.py | 2 ++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index 9ad08d24..9d661f6e 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -1203,7 +1203,10 @@ def train( epoch_loss_values = [] val_metric_values = [] - if len(self.config.train_data_dict) > 1: + if ( + len(self.config.train_data_dict) > 1 + and self.config.eval_data_dict is None + ): self.train_files, self.val_files = ( self.config.train_data_dict[ 0 : int( @@ -1218,6 +1221,11 @@ def train( ) : ], ) + elif self.config.eval_data_dict is not None: + # train files are used as is, validation files are from eval_data_dict + # not used in the plugin yet, only for training via the API + self.train_files = self.config.train_data_dict + self.val_files = self.config.eval_data_dict else: self.train_files = self.val_files = self.config.train_data_dict msg = f"Only one image file was provided : {self.config.train_data_dict[0]['image']}.\n" @@ -1591,6 +1599,10 @@ def get_patch_loader_func(num_samples): val_data["image"].to(device), val_data["label"].to(device), ) + if self.labels_not_semantic: + val_labels = torch.where( + val_labels > 1, 1, val_labels + ) try: with torch.no_grad(): @@ -1624,7 +1636,11 @@ def get_patch_loader_func(num_samples): EnsureType(), ] ) # - post_label = EnsureType() + post_label = Compose( + [ + EnsureType(), + ] + ) output_raw = [ RemapTensor(new_max=1, new_min=0)(t) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 002c6024..7f8baf05 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -348,11 +348,13 @@ class SupervisedTrainingWorkerConfig(TrainingWorkerConfig): """Class to record config for Trainer plugin. Args: + eval_data_dict (dict): dict of eval data as {"image": np.array, "labels": np.array}. Optional. model_info (ModelInfo): model info loss_function (callable): loss function validation_percent (float): validation percent """ + eval_data_dict: dict = None model_info: ModelInfo = None loss_function: callable = None training_percent: float = 0.8 From 81984783249e06f2b4cdc62a30e56cddad0ef652 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 29 Apr 2024 11:57:16 +0200 Subject: [PATCH 2/8] Update worker_training.py --- napari_cellseg3d/code_models/worker_training.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index 9d661f6e..4acbd65a 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -402,7 +402,11 @@ def log_parameters(self): self.log("*" * 20) def train( - self, provided_model=None, provided_optimizer=None, provided_loss=None + self, + provided_model=None, + provided_optimizer=None, + provided_loss=None, + wandb_name_override=None, ): """Main training function. @@ -412,6 +416,7 @@ def train( provided_model (WNet, optional): A model to use for training. Defaults to None. provided_optimizer (torch.optim.Optimizer, optional): An optimizer to use for training. Defaults to None. provided_loss (torch.nn.Module, optional): A loss function to use for training. Defaults to None. + wandb_name_override (str, optional): A name to override the wandb run name. Defaults to None. """ try: if self.config is None: @@ -431,7 +436,9 @@ def train( wandb.init( config=config_dict, project="CellSeg3D - WNet", - name=f"WNet_training - {utils.get_date_time()}", + name=f"WNet_training - {utils.get_date_time()}" + if wandb_name_override is None + else wandb_name_override, mode=self.wandb_config.mode, tags=["WNet", "training"], ) @@ -1079,6 +1086,7 @@ def train( provided_optimizer=None, provided_loss=None, provided_scheduler=None, + wandb_name_override=None, ): """Trains the PyTorch model for the given number of epochs. @@ -1142,7 +1150,9 @@ def train( wandb.init( config=config_dict, project="CellSeg3D", - name=f"{model_config.name}_supervised_training - {utils.get_date_time()}", + name=f"{model_config.name}_supervised_training - {utils.get_date_time()}" + if wandb_name_override is None + else wandb_name_override, tags=[f"{model_config.name}", "supervised"], mode=self.wandb_config.mode, ) From f1e46ae36982619e7ced439b73b3a82a324f06c3 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 3 May 2024 14:24:25 +0200 Subject: [PATCH 3/8] Change name to WNet3D --- .../models/pretrained/pretrained_model_urls.json | 1 + napari_cellseg3d/code_models/models/wnet/model.py | 2 +- napari_cellseg3d/code_plugins/plugin_model_inference.py | 4 ++-- napari_cellseg3d/code_plugins/plugin_model_training.py | 4 ++-- napari_cellseg3d/dev_scripts/colab_training.py | 6 +++--- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json index d9e1e4b0..8bf72b27 100644 --- a/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json +++ b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json @@ -4,6 +4,7 @@ "VNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/VNet_latest.tar.gz", "SwinUNetR": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/SwinUNetR_latest.tar.gz", "WNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/wnet_latest.tar.gz", + "WNet3D": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/wnet_latest.tar.gz", "WNet_ONNX": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/wnet_onnx.tar.gz", "test": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/test.tar.gz" } diff --git a/napari_cellseg3d/code_models/models/wnet/model.py b/napari_cellseg3d/code_models/models/wnet/model.py index 4817e307..e0e5b764 100644 --- a/napari_cellseg3d/code_models/models/wnet/model.py +++ b/napari_cellseg3d/code_models/models/wnet/model.py @@ -17,7 +17,7 @@ class WNet_encoder(nn.Module): - """WNet with encoder only.""" + """WNet3D with encoder only.""" def __init__( self, diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 27ff2d5e..fd4c37e5 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -258,7 +258,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self._remove_unused() def _toggle_crf_choice(self): - if self.model_choice.currentText() == "WNet": + if self.model_choice.currentText() == "WNet3D": self.use_crf.setVisible(True) else: self.use_crf.setVisible(False) @@ -335,7 +335,7 @@ def check_ready(self): def _restrict_window_size_for_model(self): """Sets the window size to a value that is compatible with the chosen model.""" self.wnet_enabled = False - if self.model_choice.currentText() == "WNet": + if self.model_choice.currentText() == "WNet3D": self.wnet_enabled = True self.window_size_choice.setCurrentIndex(self._default_window_size) self.use_window_choice.setChecked(self.wnet_enabled) diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index a668e155..67af40e8 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -451,7 +451,7 @@ def check_ready(self): def _toggle_unsupervised_mode(self, enabled=False): """Change all the UI elements needed for unsupervised learning mode.""" - if self.model_choice.currentText() == "WNet" or enabled: + if self.model_choice.currentText() == "WNet3D" or enabled: unsupervised = True self.start_btn = self.start_button_unsupervised if self.image_filewidget.text_field.text() == "Images directory": @@ -799,7 +799,7 @@ def _build(self): ui.add_blank(advanced_tab, advanced_tab.layout) ################## model_params_group_w, model_params_group_l = ui.make_group( - "WNet parameters", r=20, b=5, t=11 + "WNet3D parameters", r=20, b=5, t=11 ) ui.add_widgets( model_params_group_l, diff --git a/napari_cellseg3d/dev_scripts/colab_training.py b/napari_cellseg3d/dev_scripts/colab_training.py index ef5245e8..e8ca1f91 100644 --- a/napari_cellseg3d/dev_scripts/colab_training.py +++ b/napari_cellseg3d/dev_scripts/colab_training.py @@ -332,7 +332,7 @@ def train( project="CellSeg3D (Colab)", name=f"{self.config.model_info.name} training - {utils.get_date_time()}", mode=self.wandb_config.mode, - tags=["WNet", "Colab"], + tags=["WNet3D", "Colab"], ) set_determinism(seed=self.config.deterministic_config.seed) @@ -379,7 +379,7 @@ def train( if self.config.weights_info.use_custom: if self.config.weights_info.use_pretrained: weights_file = "wnet.pth" - self.downloader.download_weights("WNet", weights_file) + self.downloader.download_weights("WNet3D", weights_file) weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file) self.config.weights_info.path = weights else: @@ -596,7 +596,7 @@ def train( if WANDB_INSTALLED and self.wandb_config.save_model_artifact: model_artifact = wandb.Artifact( - "WNet", + "WNet3D", type="model", description="CellSeg3D WNet", metadata=self.config.__dict__, From 066ea703abd7849b55a25b04cdef83037b8023d0 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 3 May 2024 14:26:42 +0200 Subject: [PATCH 4/8] Update worker_training.py --- .../code_models/worker_training.py | 47 +++++-------------- 1 file changed, 13 insertions(+), 34 deletions(-) diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index 4acbd65a..0ffec8f9 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -402,11 +402,7 @@ def log_parameters(self): self.log("*" * 20) def train( - self, - provided_model=None, - provided_optimizer=None, - provided_loss=None, - wandb_name_override=None, + self, provided_model=None, provided_optimizer=None, provided_loss=None ): """Main training function. @@ -416,7 +412,6 @@ def train( provided_model (WNet, optional): A model to use for training. Defaults to None. provided_optimizer (torch.optim.Optimizer, optional): An optimizer to use for training. Defaults to None. provided_loss (torch.nn.Module, optional): A loss function to use for training. Defaults to None. - wandb_name_override (str, optional): A name to override the wandb run name. Defaults to None. """ try: if self.config is None: @@ -436,11 +431,9 @@ def train( wandb.init( config=config_dict, project="CellSeg3D - WNet", - name=f"WNet_training - {utils.get_date_time()}" - if wandb_name_override is None - else wandb_name_override, + name=f"WNet3D_training - {utils.get_date_time()}", mode=self.wandb_config.mode, - tags=["WNet", "training"], + tags=["WNet3D", "training"], ) set_determinism(seed=self.config.deterministic_config.seed) @@ -494,7 +487,7 @@ def train( ) weights_file = WNet_.weights_file - self.downloader.download_weights("WNet", weights_file) + self.downloader.download_weights("WNet3D", weights_file) weights = str(PRETRAINED_WEIGHTS_DIR / Path(weights_file)) self.config.weights_info.path = weights @@ -788,7 +781,7 @@ def train( if WANDB_INSTALLED and self.wandb_config.save_model_artifact: model_artifact = wandb.Artifact( - "WNet", + "WNet3D", type="model", description="CellSeg3D WNet", metadata=self.config.__dict__, @@ -1086,7 +1079,6 @@ def train( provided_optimizer=None, provided_loss=None, provided_scheduler=None, - wandb_name_override=None, ): """Trains the PyTorch model for the given number of epochs. @@ -1135,6 +1127,11 @@ def train( weights_config = self.config.weights_info deterministic_config = self.config.deterministic_config + if self.config.device == "mps": + from os import environ + + environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + start_time = time.time() try: @@ -1150,9 +1147,7 @@ def train( wandb.init( config=config_dict, project="CellSeg3D", - name=f"{model_config.name}_supervised_training - {utils.get_date_time()}" - if wandb_name_override is None - else wandb_name_override, + name=f"{model_config.name}_supervised_training - {utils.get_date_time()}", tags=[f"{model_config.name}", "supervised"], mode=self.wandb_config.mode, ) @@ -1213,10 +1208,7 @@ def train( epoch_loss_values = [] val_metric_values = [] - if ( - len(self.config.train_data_dict) > 1 - and self.config.eval_data_dict is None - ): + if len(self.config.train_data_dict) > 1: self.train_files, self.val_files = ( self.config.train_data_dict[ 0 : int( @@ -1231,11 +1223,6 @@ def train( ) : ], ) - elif self.config.eval_data_dict is not None: - # train files are used as is, validation files are from eval_data_dict - # not used in the plugin yet, only for training via the API - self.train_files = self.config.train_data_dict - self.val_files = self.config.eval_data_dict else: self.train_files = self.val_files = self.config.train_data_dict msg = f"Only one image file was provided : {self.config.train_data_dict[0]['image']}.\n" @@ -1609,10 +1596,6 @@ def get_patch_loader_func(num_samples): val_data["image"].to(device), val_data["label"].to(device), ) - if self.labels_not_semantic: - val_labels = torch.where( - val_labels > 1, 1, val_labels - ) try: with torch.no_grad(): @@ -1646,11 +1629,7 @@ def get_patch_loader_func(num_samples): EnsureType(), ] ) # - post_label = Compose( - [ - EnsureType(), - ] - ) + post_label = EnsureType() output_raw = [ RemapTensor(new_max=1, new_min=0)(t) From 350fecc80e55d5345b7aede4488062e9ddc3839e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 3 May 2024 15:01:42 +0200 Subject: [PATCH 5/8] Replace CellSeg3d with CellSeg3D --- .napari/DESCRIPTION.md | 2 +- .napari/config.yml | 10 +++++----- README.md | 14 +++++++------- docs/_config.yml | 2 +- docs/source/guides/custom_model_template.rst | 2 +- docs/source/guides/detailed_walkthrough.rst | 4 ++-- docs/source/guides/installation_guide.rst | 6 +++--- docs/source/guides/training_wnet.rst | 2 +- examples/README.md | 7 +++++++ napari_cellseg3d/code_plugins/plugin_helper.py | 4 ++-- napari_cellseg3d/dev_scripts/classifier_test.ipynb | 2 +- .../dev_scripts/test_new_evaluation.ipynb | 6 +++--- notebooks/colab_wnet_training.ipynb | 4 ++-- pyproject.toml | 4 ++-- 14 files changed, 38 insertions(+), 31 deletions(-) create mode 100644 examples/README.md diff --git a/.napari/DESCRIPTION.md b/.napari/DESCRIPTION.md index 9a3143bb..d379ed68 100644 --- a/.napari/DESCRIPTION.md +++ b/.napari/DESCRIPTION.md @@ -117,7 +117,7 @@ this information here. ## Getting Help If you would like to report an issue with the plugin, -please open an [issue on Github](https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues) +please open an [issue on Github](https://github.com/AdaptiveMotorControlLab/CellSeg3D/issues) 34\u001b[0m view \u001b[39m=\u001b[39m napari\u001b[39m.\u001b[39;49mview_image(preds, colormap\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mturbo\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[0;32m 35\u001b[0m view\u001b[39m.\u001b[39madd_image(test_image, colormap\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mturbo\u001b[39m\u001b[39m\"\u001b[39m, blending\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39madditive\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m 36\u001b[0m view\u001b[39m.\u001b[39madd_image(rejected, colormap\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mturbo\u001b[39m\u001b[39m\"\u001b[39m, blending\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39madditive\u001b[39m\u001b[39m\"\u001b[39m)\n", + "\u001b[1;32mc:\\Users\\Cyril\\Desktop\\Code\\CellSeg3D\\napari_cellseg3d\\dev_scripts\\classifier_test.ipynb Cell 12\u001b[0m line \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 31\u001b[0m preds[crop_location_i:crop_location_i\u001b[39m+\u001b[39mcube_size, crop_location_j:crop_location_j\u001b[39m+\u001b[39mcube_size, crop_location_k:crop_location_k\u001b[39m+\u001b[39mcube_size] \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m\n\u001b[0;32m 32\u001b[0m rejected[crop_location_i:crop_location_i\u001b[39m+\u001b[39mcube_size, crop_location_j:crop_location_j\u001b[39m+\u001b[39mcube_size, crop_location_k:crop_location_k\u001b[39m+\u001b[39mcube_size] \u001b[39m=\u001b[39m crop\n\u001b[1;32m---> 34\u001b[0m view \u001b[39m=\u001b[39m napari\u001b[39m.\u001b[39;49mview_image(preds, colormap\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mturbo\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[0;32m 35\u001b[0m view\u001b[39m.\u001b[39madd_image(test_image, colormap\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mturbo\u001b[39m\u001b[39m\"\u001b[39m, blending\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39madditive\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m 36\u001b[0m view\u001b[39m.\u001b[39madd_image(rejected, colormap\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mturbo\u001b[39m\u001b[39m\"\u001b[39m, blending\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39madditive\u001b[39m\u001b[39m\"\u001b[39m)\n", "File \u001b[1;32mc:\\Users\\Cyril\\anaconda3\\envs\\cellseg3d\\lib\\site-packages\\napari\\view_layers.py:178\u001b[0m, in \u001b[0;36mview_image\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 176\u001b[0m \u001b[39m@_merge_layer_viewer_sigs_docs\u001b[39m\n\u001b[0;32m 177\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mview_image\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m--> 178\u001b[0m \u001b[39mreturn\u001b[39;00m _make_viewer_then(\u001b[39m'\u001b[39;49m\u001b[39madd_image\u001b[39;49m\u001b[39m'\u001b[39;49m, args, kwargs)[\u001b[39m0\u001b[39m]\n", "File \u001b[1;32mc:\\Users\\Cyril\\anaconda3\\envs\\cellseg3d\\lib\\site-packages\\napari\\view_layers.py:156\u001b[0m, in \u001b[0;36m_make_viewer_then\u001b[1;34m(add_method, args, kwargs)\u001b[0m\n\u001b[0;32m 154\u001b[0m viewer \u001b[39m=\u001b[39m kwargs\u001b[39m.\u001b[39mpop(\u001b[39m\"\u001b[39m\u001b[39mviewer\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39mNone\u001b[39;00m)\n\u001b[0;32m 155\u001b[0m \u001b[39mif\u001b[39;00m viewer \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m--> 156\u001b[0m viewer \u001b[39m=\u001b[39m Viewer(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mvkwargs)\n\u001b[0;32m 157\u001b[0m kwargs\u001b[39m.\u001b[39mupdate(kwargs\u001b[39m.\u001b[39mpop(\u001b[39m\"\u001b[39m\u001b[39mkwargs\u001b[39m\u001b[39m\"\u001b[39m, {}))\n\u001b[0;32m 158\u001b[0m method \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(viewer, add_method)\n", "File \u001b[1;32mc:\\Users\\Cyril\\anaconda3\\envs\\cellseg3d\\lib\\site-packages\\napari\\viewer.py:67\u001b[0m, in \u001b[0;36mViewer.__init__\u001b[1;34m(self, title, ndisplay, order, axis_labels, show)\u001b[0m\n\u001b[0;32m 63\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mnapari\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mwindow\u001b[39;00m \u001b[39mimport\u001b[39;00m Window\n\u001b[0;32m 65\u001b[0m _initialize_plugins()\n\u001b[1;32m---> 67\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_window \u001b[39m=\u001b[39m Window(\u001b[39mself\u001b[39;49m, show\u001b[39m=\u001b[39;49mshow)\n\u001b[0;32m 68\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_instances\u001b[39m.\u001b[39madd(\u001b[39mself\u001b[39m)\n", diff --git a/napari_cellseg3d/dev_scripts/test_new_evaluation.ipynb b/napari_cellseg3d/dev_scripts/test_new_evaluation.ipynb index 12707e9b..dcb7ace9 100644 --- a/napari_cellseg3d/dev_scripts/test_new_evaluation.ipynb +++ b/napari_cellseg3d/dev_scripts/test_new_evaluation.ipynb @@ -24,7 +24,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 2, @@ -62,8 +62,8 @@ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mIndexError\u001b[0m Traceback (most recent call last)", "Cell \u001b[1;32mIn[16], line 4\u001b[0m\n\u001b[0;32m 2\u001b[0m labels \u001b[38;5;241m=\u001b[39m imread(path_model_label)\n\u001b[0;32m 3\u001b[0m \u001b[38;5;66;03m# labels.shape\u001b[39;00m\n\u001b[1;32m----> 4\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mevl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate_model_performance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimread\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath_true_labels\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43mvisualize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_graphical_summary\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43mplot_according_to_gt_label\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n", - "File \u001b[1;32m~\\Desktop\\Code\\CellSeg3d\\napari_cellseg3d\\dev_scripts\\evaluate_labels.py:58\u001b[0m, in \u001b[0;36mevaluate_model_performance\u001b[1;34m(labels, model_labels, threshold_correct, print_details, visualize, return_graphical_summary, plot_according_to_gt_label)\u001b[0m\n\u001b[0;32m 20\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Evaluate the model performance.\u001b[39;00m\n\u001b[0;32m 21\u001b[0m \u001b[38;5;124;03mParameters\u001b[39;00m\n\u001b[0;32m 22\u001b[0m \u001b[38;5;124;03m----------\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 55\u001b[0m \u001b[38;5;124;03mgraph_true_positive_ratio_model: ndarray\u001b[39;00m\n\u001b[0;32m 56\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 57\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMapping labels...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m---> 58\u001b[0m tmp \u001b[38;5;241m=\u001b[39m \u001b[43mmap_labels\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 59\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 60\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_labels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 61\u001b[0m \u001b[43m \u001b[49m\u001b[43mthreshold_correct\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 62\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_total_number_gt_labels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 63\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 64\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_graphical_summary\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_graphical_summary\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 65\u001b[0m \u001b[43m \u001b[49m\u001b[43mplot_according_to_gt_labels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplot_according_to_gt_label\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 66\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 67\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_graphical_summary:\n\u001b[0;32m 68\u001b[0m (\n\u001b[0;32m 69\u001b[0m map_labels_existing,\n\u001b[0;32m 70\u001b[0m map_fused_neurons,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 75\u001b[0m graph_true_positive_ratio_model,\n\u001b[0;32m 76\u001b[0m ) \u001b[38;5;241m=\u001b[39m tmp\n", - "File \u001b[1;32m~\\Desktop\\Code\\CellSeg3d\\napari_cellseg3d\\dev_scripts\\evaluate_labels.py:422\u001b[0m, in \u001b[0;36mmap_labels\u001b[1;34m(gt_labels, model_labels, threshold_correct, return_total_number_gt_labels, return_dict_map, accuracy_function, return_graphical_summary, plot_according_to_gt_labels)\u001b[0m\n\u001b[0;32m 419\u001b[0m \u001b[38;5;66;03m# remove from new_labels the labels that are in map_labels_existing\u001b[39;00m\n\u001b[0;32m 420\u001b[0m new_labels \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(new_labels)\n\u001b[0;32m 421\u001b[0m i_new_labels \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39misin(\n\u001b[1;32m--> 422\u001b[0m \u001b[43mnew_labels\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdict_map\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodel_label\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m,\n\u001b[0;32m 423\u001b[0m map_labels_existing[:, dict_map[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_label\u001b[39m\u001b[38;5;124m\"\u001b[39m]],\n\u001b[0;32m 424\u001b[0m invert\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[0;32m 425\u001b[0m )\n\u001b[0;32m 426\u001b[0m new_labels \u001b[38;5;241m=\u001b[39m new_labels[i_new_labels, :]\n\u001b[0;32m 427\u001b[0m \u001b[38;5;66;03m# find the fused neurons: multiple gt labels are mapped to the same model label\u001b[39;00m\n", + "File \u001b[1;32m~\\Desktop\\Code\\CellSeg3D\\napari_cellseg3d\\dev_scripts\\evaluate_labels.py:58\u001b[0m, in \u001b[0;36mevaluate_model_performance\u001b[1;34m(labels, model_labels, threshold_correct, print_details, visualize, return_graphical_summary, plot_according_to_gt_label)\u001b[0m\n\u001b[0;32m 20\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Evaluate the model performance.\u001b[39;00m\n\u001b[0;32m 21\u001b[0m \u001b[38;5;124;03mParameters\u001b[39;00m\n\u001b[0;32m 22\u001b[0m \u001b[38;5;124;03m----------\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 55\u001b[0m \u001b[38;5;124;03mgraph_true_positive_ratio_model: ndarray\u001b[39;00m\n\u001b[0;32m 56\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 57\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMapping labels...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m---> 58\u001b[0m tmp \u001b[38;5;241m=\u001b[39m \u001b[43mmap_labels\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 59\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 60\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_labels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 61\u001b[0m \u001b[43m \u001b[49m\u001b[43mthreshold_correct\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 62\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_total_number_gt_labels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 63\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 64\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_graphical_summary\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_graphical_summary\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 65\u001b[0m \u001b[43m \u001b[49m\u001b[43mplot_according_to_gt_labels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplot_according_to_gt_label\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 66\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 67\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_graphical_summary:\n\u001b[0;32m 68\u001b[0m (\n\u001b[0;32m 69\u001b[0m map_labels_existing,\n\u001b[0;32m 70\u001b[0m map_fused_neurons,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 75\u001b[0m graph_true_positive_ratio_model,\n\u001b[0;32m 76\u001b[0m ) \u001b[38;5;241m=\u001b[39m tmp\n", + "File \u001b[1;32m~\\Desktop\\Code\\CellSeg3D\\napari_cellseg3d\\dev_scripts\\evaluate_labels.py:422\u001b[0m, in \u001b[0;36mmap_labels\u001b[1;34m(gt_labels, model_labels, threshold_correct, return_total_number_gt_labels, return_dict_map, accuracy_function, return_graphical_summary, plot_according_to_gt_labels)\u001b[0m\n\u001b[0;32m 419\u001b[0m \u001b[38;5;66;03m# remove from new_labels the labels that are in map_labels_existing\u001b[39;00m\n\u001b[0;32m 420\u001b[0m new_labels \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(new_labels)\n\u001b[0;32m 421\u001b[0m i_new_labels \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39misin(\n\u001b[1;32m--> 422\u001b[0m \u001b[43mnew_labels\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdict_map\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodel_label\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m,\n\u001b[0;32m 423\u001b[0m map_labels_existing[:, dict_map[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_label\u001b[39m\u001b[38;5;124m\"\u001b[39m]],\n\u001b[0;32m 424\u001b[0m invert\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[0;32m 425\u001b[0m )\n\u001b[0;32m 426\u001b[0m new_labels \u001b[38;5;241m=\u001b[39m new_labels[i_new_labels, :]\n\u001b[0;32m 427\u001b[0m \u001b[38;5;66;03m# find the fused neurons: multiple gt labels are mapped to the same model label\u001b[39;00m\n", "\u001b[1;31mIndexError\u001b[0m: too many indices for array: array is 1-dimensional, but 2 were indexed" ] } diff --git a/notebooks/colab_wnet_training.ipynb b/notebooks/colab_wnet_training.ipynb index 5b984673..b5b4c98a 100644 --- a/notebooks/colab_wnet_training.ipynb +++ b/notebooks/colab_wnet_training.ipynb @@ -24,7 +24,7 @@ "---\n", "*Disclaimer:*\n", "\n", - "This notebook, part of the [CellSeg3D project](https://github.com/AdaptiveMotorControlLab/CellSeg3d) under the [Mathis Lab of Adaptive Motor Control](https://www.mackenziemathislab.org/), is a work-in-progress resource for training the WNet model for unsupervised cell segmentation.\n", + "This notebook, part of the [CellSeg3D project](https://github.com/AdaptiveMotorControlLab/CellSeg3D) under the [Mathis Lab of Adaptive Motor Control](https://www.mackenziemathislab.org/), is a work-in-progress resource for training the WNet model for unsupervised cell segmentation.\n", "\n", "The foundation of this notebook owes much to the **[ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic)** project —a collaborative effort between the Jacquemet and Henriques laboratories, and created by Daniel Krentzel. Except for the model provided herein, all credits are duly given to their team." ], @@ -226,7 +226,7 @@ ], "source": [ "#@markdown ##Play to install WNet dependencies\n", - "!git clone https://github.com/AdaptiveMotorControlLab/CellSeg3d.git --branch cy/wnet-extras --single-branch ./CellSeg3D\n", + "!git clone https://github.com/AdaptiveMotorControlLab/CellSeg3D.git --branch cy/wnet-extras --single-branch ./CellSeg3D\n", "!pip install -e CellSeg3D" ] }, diff --git a/pyproject.toml b/pyproject.toml index 42a52f37..392cc99e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,9 +49,9 @@ dependencies = [ dynamic = ["version", "entry-points"] [project.urls] -Homepage = "https://github.com/AdaptiveMotorControlLab/CellSeg3d" +Homepage = "https://github.com/AdaptiveMotorControlLab/CellSeg3D" Documentation = "https://adaptivemotorcontrollab.github.io/cellseg3d-docs/res/welcome.html" -Issues = "https://github.com/AdaptiveMotorControlLab/CellSeg3d/issues" +Issues = "https://github.com/AdaptiveMotorControlLab/CellSeg3D/issues" [build-system] requires = ["setuptools", "wheel"] From 79b75e2b9cd7f7d976a6abf9cf712297836359be Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 3 May 2024 15:14:53 +0200 Subject: [PATCH 6/8] Update test_training.py --- napari_cellseg3d/_tests/test_training.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index c095b4e9..2c5f7b7f 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -87,7 +87,8 @@ def test_unsupervised_training(make_napari_viewer_proxy): widget.log = LogFixture() widget.worker = None widget._toggle_unsupervised_mode(enabled=True) - widget.model_choice.setCurrentText("WNet") + widget.model_choice.setCurrentText("WNet3D") + widget._toggle_unsupervised_mode(enabled=True) widget.patch_choice.setChecked(True) [w.setValue(4) for w in widget.patch_size_widgets] From 843199bf9673d9490db837ebd4e5b76e5ed1d9fd Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 3 May 2024 15:15:22 +0200 Subject: [PATCH 7/8] Update test_training.py --- napari_cellseg3d/_tests/test_training.py | 1 - 1 file changed, 1 deletion(-) diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 2c5f7b7f..15ec119e 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -86,7 +86,6 @@ def test_unsupervised_training(make_napari_viewer_proxy): widget = Trainer(viewer) widget.log = LogFixture() widget.worker = None - widget._toggle_unsupervised_mode(enabled=True) widget.model_choice.setCurrentText("WNet3D") widget._toggle_unsupervised_mode(enabled=True) From c6da7c88bd710905ee91b49a45125bbccb2f48e2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 6 May 2024 13:21:10 +0200 Subject: [PATCH 8/8] Fix missing Wnet3d + fix tests (utils) --- napari_cellseg3d/_tests/test_plugin_inference.py | 2 +- napari_cellseg3d/_tests/test_plugin_training.py | 2 +- napari_cellseg3d/_tests/test_plugin_utils.py | 2 +- napari_cellseg3d/code_models/worker_inference.py | 2 +- napari_cellseg3d/config.py | 4 +--- 5 files changed, 5 insertions(+), 7 deletions(-) diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py index de518b3d..03565045 100644 --- a/napari_cellseg3d/_tests/test_plugin_inference.py +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -31,7 +31,7 @@ def test_inference(make_napari_viewer_proxy, qtbot): assert widget.check_ready() - widget.model_choice.setCurrentText("WNet") + widget.model_choice.setCurrentText("WNet3D") widget._restrict_window_size_for_model() assert widget.use_window_choice.isChecked() assert widget.window_size_choice.currentText() == "64" diff --git a/napari_cellseg3d/_tests/test_plugin_training.py b/napari_cellseg3d/_tests/test_plugin_training.py index 41102bbd..7ea61200 100644 --- a/napari_cellseg3d/_tests/test_plugin_training.py +++ b/napari_cellseg3d/_tests/test_plugin_training.py @@ -33,7 +33,7 @@ def test_worker_configs(make_napari_viewer_proxy): worker.config, attr ) # test unsupervised config and worker - widget.model_choice.setCurrentText("WNet") + widget.model_choice.setCurrentText("WNet3D") widget._toggle_unsupervised_mode(enabled=True) default_config = config.WNetTrainingWorkerConfig() worker = widget._create_worker(additional_results_description="TEST_1") diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 9e668fc5..3f3da7b3 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -38,7 +38,7 @@ def test_crop_widget(make_napari_viewer_proxy): view = make_napari_viewer_proxy() widget = Cropping(view) - image = rand_gen.random((10, 10, 10)).astype(np.uint8) + image = rand_gen.random((10, 10, 10)).astype(np.int8) image_layer_1 = view.add_image(image, name="image") image_layer_2 = view.add_labels(image, name="image2") diff --git a/napari_cellseg3d/code_models/worker_inference.py b/napari_cellseg3d/code_models/worker_inference.py index c69dfe45..46ba77eb 100644 --- a/napari_cellseg3d/code_models/worker_inference.py +++ b/napari_cellseg3d/code_models/worker_inference.py @@ -277,7 +277,7 @@ def load_layer(self): normalization = ( QuantileNormalization() - if self.config.model_info.name != "WNet" + if self.config.model_info.name != "WNet3D" else lambda x: x ) volume = np.reshape(volume, newshape=(1, *volume.shape)) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py index 7f8baf05..f319802d 100644 --- a/napari_cellseg3d/config.py +++ b/napari_cellseg3d/config.py @@ -26,7 +26,7 @@ "VNet": VNet_, "TRAILMAP_MS": TRAILMAP_MS_, "SwinUNetR": SwinUNETR_, - "WNet": WNet_, + "WNet3D": WNet_, # "TRAILMAP": TRAILMAP, # "test" : DO NOT USE, reserved for testing } @@ -348,13 +348,11 @@ class SupervisedTrainingWorkerConfig(TrainingWorkerConfig): """Class to record config for Trainer plugin. Args: - eval_data_dict (dict): dict of eval data as {"image": np.array, "labels": np.array}. Optional. model_info (ModelInfo): model info loss_function (callable): loss function validation_percent (float): validation percent """ - eval_data_dict: dict = None model_info: ModelInfo = None loss_function: callable = None training_percent: float = 0.8