From 726cc1d19bf80ab7977bef263eb28236742afb43 Mon Sep 17 00:00:00 2001 From: Andres Date: Tue, 14 Sep 2021 03:24:28 +0100 Subject: [PATCH 01/16] Add Epistemic strategy to DeepEdit App Signed-off-by: Andres --- monailabel/tasks/scoring/epistemic.py | 6 ++--- sample-apps/deepedit/main.py | 36 +++++++++++++++++++-------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/monailabel/tasks/scoring/epistemic.py b/monailabel/tasks/scoring/epistemic.py index 6bea0ace5..101436d50 100644 --- a/monailabel/tasks/scoring/epistemic.py +++ b/monailabel/tasks/scoring/epistemic.py @@ -26,11 +26,11 @@ class EpistemicScoring(ScoringMethod): """ - First version of test time augmentation active learning + First version of Epistemic computation used as active learning strategy """ def __init__(self, model, network=None, transforms=None, roi_size=(128, 128, 64), num_samples=10): - super().__init__("Compute initial score based on TTA") + super().__init__("Compute initial score based on dropout") self.model = model self.network = network self.transforms = transforms @@ -155,7 +155,7 @@ def __call__(self, request, datastore: Datastore): accum_numpy = np.stack(accum_unl_outputs) accum_numpy = np.squeeze(accum_numpy) - accum_numpy = accum_numpy[:, 1, :, :, :] + accum_numpy = accum_numpy[:, 1, :, :, :] if len(accum_numpy.shape) > 4 else accum_numpy entropy = self.entropy_3d_volume(accum_numpy) entropy_sum = float(np.sum(entropy)) diff --git a/sample-apps/deepedit/main.py b/sample-apps/deepedit/main.py index 7d3823194..ad8c6f071 100644 --- a/sample-apps/deepedit/main.py +++ b/sample-apps/deepedit/main.py @@ -28,6 +28,7 @@ from monailabel.tasks.activelearning.random import Random from monailabel.tasks.activelearning.tta import TTA from monailabel.tasks.scoring.dice import Dice +from monailabel.tasks.scoring.epistemic import EpistemicScoring from monailabel.tasks.scoring.sum import Sum from monailabel.tasks.scoring.tta import TTAScoring from monailabel.utils.others.planner import HeuristicPlanner @@ -37,11 +38,11 @@ class MyApp(MONAILabelApp): def __init__(self, app_dir, studies, conf): - self.network = DynUNetV1( - spatial_dims=3, - in_channels=3, - out_channels=1, - kernel_size=[ + network_params = { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 1, + "kernel_size": [ [3, 3, 3], [3, 3, 3], [3, 3, 3], @@ -49,7 +50,7 @@ def __init__(self, app_dir, studies, conf): [3, 3, 3], [3, 3, 3], ], - strides=[ + "strides": [ [1, 1, 1], [2, 2, 2], [2, 2, 2], @@ -57,17 +58,19 @@ def __init__(self, app_dir, studies, conf): [2, 2, 2], [2, 2, 1], ], - upsample_kernel_size=[ + "upsample_kernel_size": [ [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 1], ], - norm_name="instance", - deep_supervision=False, - res_block=True, - ) + "norm_name": "instance", + "deep_supervision": False, + "res_block": True, + } + self.network = DynUNetV1(**network_params) + self.network_with_dropout = DynUNetV1(**network_params, dropout=0.2) self.model_dir = os.path.join(app_dir, "model") self.pretrained_model = os.path.join(self.model_dir, "pretrained.pt") @@ -86,6 +89,10 @@ def __init__(self, app_dir, studies, conf): if use_pretrained_model: self.download([(self.pretrained_model, pretrained_model_uri)]) + self.epistemic_enabled = strtobool(conf.get("epistemic_enabled", "false")) + self.epistemic_samples = int(conf.get("epistemic_samples", "5")) + logger.info(f"EPISTEMIC Enabled: {self.epistemic_enabled}; Samples: {self.epistemic_samples}") + self.tta_enabled = strtobool(conf.get("tta_enabled", "false")) self.tta_samples = int(conf.get("tta_samples", "5")) logger.info(f"TTA Enabled: {self.tta_enabled}; Samples: {self.tta_samples}") @@ -145,6 +152,13 @@ def init_strategies(self) -> Dict[str, Strategy]: def init_scoring_methods(self) -> Dict[str, ScoringMethod]: methods: Dict[str, ScoringMethod] = {} + if self.epistemic_enabled: + methods["EPISTEMIC"] = EpistemicScoring( + model=[self.pretrained_model, self.final_model], + network=self.network_with_dropout, + transforms=self._infers["deepedit_seg"].pre_transforms(), + num_samples=self.epistemic_samples, + ) if self.tta_enabled: methods["TTA"] = TTAScoring( model=[self.pretrained_model, self.final_model], From 20c8eec81d59ff54e355fdac367a0151db9b8a48 Mon Sep 17 00:00:00 2001 From: Andres Date: Tue, 14 Sep 2021 03:29:08 +0100 Subject: [PATCH 02/16] Update link to DeepEdit App Signed-off-by: Andres --- sample-apps/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sample-apps/README.md b/sample-apps/README.md index 3ab5e42a6..92fb6e37f 100644 --- a/sample-apps/README.md +++ b/sample-apps/README.md @@ -12,7 +12,7 @@ the [deepgrow App](./deepgrow). The latter one is meant for the users that want #### DeepEdit -Similar to the deepgrow Apps, you'll find the one generic [deepedit](./generic_deepedit) that researchers can use to build their own deepedit-based app. +Similar to the deepgrow Apps, you'll find the one generic [deepedit](./deepedit) that researchers can use to build their own deepedit-based app. #### Automated Segmentation From 480d2e72e86c2191ca6ca3851d9d2b3df714b260 Mon Sep 17 00:00:00 2001 From: Andres Date: Tue, 14 Sep 2021 13:08:11 +0100 Subject: [PATCH 03/16] Update init strategies method Signed-off-by: Andres --- sample-apps/deepedit/main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sample-apps/deepedit/main.py b/sample-apps/deepedit/main.py index ad8c6f071..04543a8e7 100644 --- a/sample-apps/deepedit/main.py +++ b/sample-apps/deepedit/main.py @@ -25,6 +25,7 @@ from monailabel.interfaces.tasks.strategy import Strategy from monailabel.interfaces.tasks.train import TrainTask from monailabel.scribbles.infer import HistogramBasedGraphCut +from monailabel.tasks.activelearning.epistemic import Epistemic from monailabel.tasks.activelearning.random import Random from monailabel.tasks.activelearning.tta import TTA from monailabel.tasks.scoring.dice import Dice @@ -144,6 +145,8 @@ def init_trainers(self) -> Dict[str, TrainTask]: def init_strategies(self) -> Dict[str, Strategy]: strategies: Dict[str, Strategy] = {} + if self.epistemic_enabled: + strategies["EPISTEMIC"] = Epistemic() if self.tta_enabled: strategies["TTA"] = TTA() strategies["random"] = Random() From efa94767d05734d4a0056eda508c33b280b00f9a Mon Sep 17 00:00:00 2001 From: Andres Date: Tue, 14 Sep 2021 13:46:47 +0100 Subject: [PATCH 04/16] Add conditional to preds - Sigmoid Signed-off-by: Andres --- monailabel/tasks/scoring/epistemic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monailabel/tasks/scoring/epistemic.py b/monailabel/tasks/scoring/epistemic.py index 101436d50..e412edb0b 100644 --- a/monailabel/tasks/scoring/epistemic.py +++ b/monailabel/tasks/scoring/epistemic.py @@ -54,7 +54,7 @@ def infer_seg(self, data, model, roi_size, sw_batch_size): inputs=data["image"][None].cuda(), roi_size=roi_size, sw_batch_size=sw_batch_size, predictor=model ) - soft_preds = torch.softmax(preds, dim=1) + soft_preds = torch.softmax(preds, dim=1) if preds.shape[1] > 1 else torch.sigmoid(preds) soft_preds = soft_preds.detach().to("cpu").numpy() return soft_preds From e03f6a18f8d0648dfb6bb1fcd477fde049c454a1 Mon Sep 17 00:00:00 2001 From: Andres Date: Tue, 14 Sep 2021 15:47:35 +0100 Subject: [PATCH 05/16] Add train mode Epistemic - eval mode TTA Signed-off-by: Andres --- monailabel/tasks/scoring/epistemic.py | 4 ++-- monailabel/tasks/scoring/tta.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monailabel/tasks/scoring/epistemic.py b/monailabel/tasks/scoring/epistemic.py index e412edb0b..31a942cdd 100644 --- a/monailabel/tasks/scoring/epistemic.py +++ b/monailabel/tasks/scoring/epistemic.py @@ -125,7 +125,7 @@ def __call__(self, request, datastore: Datastore): model, model_ts = self._load_model(self.model, self.network) if not model: return - model = model.to(self.device) + model = model.to(self.device).train() # Performing Epistemic for all unlabeled images skipped = 0 @@ -155,7 +155,7 @@ def __call__(self, request, datastore: Datastore): accum_numpy = np.stack(accum_unl_outputs) accum_numpy = np.squeeze(accum_numpy) - accum_numpy = accum_numpy[:, 1, :, :, :] if len(accum_numpy.shape) > 4 else accum_numpy + accum_numpy = accum_numpy[:, 1:, :, :, :] if len(accum_numpy.shape) > 4 else accum_numpy entropy = self.entropy_3d_volume(accum_numpy) entropy_sum = float(np.sum(entropy)) diff --git a/monailabel/tasks/scoring/tta.py b/monailabel/tasks/scoring/tta.py index d2ef1a468..604255ddc 100644 --- a/monailabel/tasks/scoring/tta.py +++ b/monailabel/tasks/scoring/tta.py @@ -141,7 +141,7 @@ def __call__(self, request, datastore: Datastore): model, model_ts = self._load_model(self.model, self.network) if not model: return - model = model.to(self.device) + model = model.to(self.device).eval() tt_aug = TestTimeAugmentation( transform=self.pre_transforms(), From 2a23ccad9d3dead501d4fb245ef2eabac06507b0 Mon Sep 17 00:00:00 2001 From: Andres Date: Wed, 15 Sep 2021 14:32:42 +0100 Subject: [PATCH 06/16] Update requirements - temporarily Signed-off-by: Andres --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 19c814ddd..e98b9b48a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ aiofiles==0.6.0 fastapi==0.65.2 -monai-weekly[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide] +monai[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide]=="git+https://github.com/Project-MONAI/MONAI.git@dev" pyyaml==5.4.1 python-multipart==0.0.5 requests-toolbelt==0.9.1 From de5330425b488349a2768d98fb6da7a828f0b364 Mon Sep 17 00:00:00 2001 From: Andres Date: Wed, 15 Sep 2021 14:36:11 +0100 Subject: [PATCH 07/16] Update requirements - temporarily - github Signed-off-by: Andres --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e98b9b48a..e0227b0bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ aiofiles==0.6.0 fastapi==0.65.2 -monai[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide]=="git+https://github.com/Project-MONAI/MONAI.git@dev" +git+https://github.com/Project-MONAI/MONAI.git@dev pyyaml==5.4.1 python-multipart==0.0.5 requests-toolbelt==0.9.1 From 7ca43e60ee98e86a594130749ae814424f9f5110 Mon Sep 17 00:00:00 2001 From: Andres Date: Wed, 15 Sep 2021 14:46:28 +0100 Subject: [PATCH 08/16] Update requirements Signed-off-by: Andres --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index e0227b0bb..fc6baacc5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiofiles==0.6.0 fastapi==0.65.2 +monai-weekly[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide] git+https://github.com/Project-MONAI/MONAI.git@dev pyyaml==5.4.1 python-multipart==0.0.5 From 1828a35512dc7ee86f290c0799b02b38b4f0f566 Mon Sep 17 00:00:00 2001 From: Andres Date: Wed, 15 Sep 2021 14:48:17 +0100 Subject: [PATCH 09/16] Update requirements Signed-off-by: Andres --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index fc6baacc5..0269da9b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ aiofiles==0.6.0 fastapi==0.65.2 monai-weekly[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide] -git+https://github.com/Project-MONAI/MONAI.git@dev +https://github.com/Project-MONAI/MONAI.git@dev pyyaml==5.4.1 python-multipart==0.0.5 requests-toolbelt==0.9.1 From 0edc2cef1027706602ee78983c45493b558266f8 Mon Sep 17 00:00:00 2001 From: Andres Date: Wed, 15 Sep 2021 14:54:08 +0100 Subject: [PATCH 10/16] Update requirements - 4 Signed-off-by: Andres --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0269da9b4..aa373cc34 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ aiofiles==0.6.0 fastapi==0.65.2 monai-weekly[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide] -https://github.com/Project-MONAI/MONAI.git@dev +https://github.com/Project-MONAI/MONAI/tree/dev#egg=monai070 pyyaml==5.4.1 python-multipart==0.0.5 requests-toolbelt==0.9.1 From e660a4b4a07d53585219e9dc2cd730baa702b1ea Mon Sep 17 00:00:00 2001 From: Andres Date: Wed, 15 Sep 2021 14:59:25 +0100 Subject: [PATCH 11/16] Update requirements - 5 Signed-off-by: Andres --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index aa373cc34..d0430ba74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ aiofiles==0.6.0 fastapi==0.65.2 monai-weekly[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide] -https://github.com/Project-MONAI/MONAI/tree/dev#egg=monai070 +https://github.com/diazandr3s/MONAI/archive/refs/tags/0.6.0.zip pyyaml==5.4.1 python-multipart==0.0.5 requests-toolbelt==0.9.1 From aa993623a9d0cc70a14ba0b20e5cce012b814a09 Mon Sep 17 00:00:00 2001 From: Andres Date: Wed, 15 Sep 2021 15:04:18 +0100 Subject: [PATCH 12/16] Update requirements - 6 Signed-off-by: Andres --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d0430ba74..d0ed8c678 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ aiofiles==0.6.0 fastapi==0.65.2 +https://github.com/diazandr3s/MONAI/archive/refs/tags/0.7.0-tempAndres.tar.gz#egg=monai monai-weekly[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide] -https://github.com/diazandr3s/MONAI/archive/refs/tags/0.6.0.zip pyyaml==5.4.1 python-multipart==0.0.5 requests-toolbelt==0.9.1 From 7a6fb57b3d9f95aa36b473672a3c0e83666b9ff8 Mon Sep 17 00:00:00 2001 From: Andres Date: Wed, 15 Sep 2021 15:09:39 +0100 Subject: [PATCH 13/16] Update requirements - 7 Signed-off-by: Andres --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d0ed8c678..34f4b8cdf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ aiofiles==0.6.0 fastapi==0.65.2 -https://github.com/diazandr3s/MONAI/archive/refs/tags/0.7.0-tempAndres.tar.gz#egg=monai +https://github.com/diazandr3s/MONAI/archive/refs/tags/v0.7.0.tar.gz#egg=monai monai-weekly[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide] pyyaml==5.4.1 python-multipart==0.0.5 From cdb2702a70237a746cd991e2c5faf17d05e9e940 Mon Sep 17 00:00:00 2001 From: Andres Date: Wed, 15 Sep 2021 15:18:43 +0100 Subject: [PATCH 14/16] Update requirements - 8 Signed-off-by: Andres --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 34f4b8cdf..03a2db2ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ aiofiles==0.6.0 fastapi==0.65.2 https://github.com/diazandr3s/MONAI/archive/refs/tags/v0.7.0.tar.gz#egg=monai -monai-weekly[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide] pyyaml==5.4.1 python-multipart==0.0.5 requests-toolbelt==0.9.1 From 97764dbccb233be285e1e08bf4b6880be3d38e74 Mon Sep 17 00:00:00 2001 From: Andres Date: Wed, 15 Sep 2021 15:33:17 +0100 Subject: [PATCH 15/16] Update requirements - 9 Signed-off-by: Andres --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 03a2db2ca..aeda87bf6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiofiles==0.6.0 fastapi==0.65.2 +monai-weekly[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide] https://github.com/diazandr3s/MONAI/archive/refs/tags/v0.7.0.tar.gz#egg=monai pyyaml==5.4.1 python-multipart==0.0.5 From ef167890ff83d69db624e43507b6ea6cfaf28d16 Mon Sep 17 00:00:00 2001 From: Andres Date: Thu, 16 Sep 2021 00:33:20 +0100 Subject: [PATCH 16/16] Update requirements monai 0.7.0rc1 Signed-off-by: Andres --- requirements.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index aeda87bf6..940e912d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ aiofiles==0.6.0 fastapi==0.65.2 -monai-weekly[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide] -https://github.com/diazandr3s/MONAI/archive/refs/tags/v0.7.0.tar.gz#egg=monai +monai[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, openslide]==0.7.0rc1 pyyaml==5.4.1 python-multipart==0.0.5 requests-toolbelt==0.9.1