From d7b910fe4252937360e95fce9197b074da28e2b9 Mon Sep 17 00:00:00 2001 From: am Date: Thu, 1 Aug 2024 19:11:39 -0700 Subject: [PATCH 01/11] cell_sam_wrapper net Signed-off-by: am --- monai/networks/nets/cell_sam_wrapper.py | 87 +++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 monai/networks/nets/cell_sam_wrapper.py diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py new file mode 100644 index 0000000000..3da0e0b5e4 --- /dev/null +++ b/monai/networks/nets/cell_sam_wrapper.py @@ -0,0 +1,87 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +from torch import nn +from torch.nn import functional as F + +try: + from segment_anything.build_sam import build_sam_vit_b +except ImportError: + raise AssertionError("SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git") + + +_all__ = ["CellSamWrapper"] + +class CellSamWrapper(torch.nn.Module): + """ + CellSamWrapper is thin wrapper around SAM model https://github.com/facebookresearch/segment-anything + with an image only decoder, that can be used for segmentation tasks. + + + Args: + auto_resize_inputs: whether to resize inputs before passing to the network. + network_resize_roi: expected input size for the network. + checkpoint: checkpoint file to load the SAM weights from. + return_features: whether to return features + + """ + + def __init__( + self, + auto_resize_inputs=True, + network_resize_roi=[1024, 1024], + checkpoint="sam_vit_b_01ec64.pth", + return_features=False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + + print( + f"CellSamWrapper auto_resize_inputs {auto_resize_inputs} network_resize_roi {network_resize_roi} checkpoint {checkpoint}" + ) + self.network_resize_roi = network_resize_roi + self.auto_resize_inputs = auto_resize_inputs + self.return_features = return_features + + model = build_sam_vit_b(checkpoint=checkpoint) + + model.prompt_encoder = None + model.mask_decoder = None + + model.mask_decoder = nn.Sequential( + nn.BatchNorm2d(num_features=256), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False), + nn.BatchNorm2d(num_features=128), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + ) + + self.model = model + + def forward(self, x): + sh = x.shape[2:] + + if self.auto_resize_inputs: + x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear") + + x = self.model.image_encoder(x) + + if not self.return_features: + x = self.model.mask_decoder(x) + if self.auto_resize_inputs: + x = F.interpolate(x, size=sh, mode="bilinear") + + return x From 51d63c91185a0b7609212d37b142e73ec5312825 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Aug 2024 02:15:40 +0000 Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: am --- monai/networks/nets/cell_sam_wrapper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py index 3da0e0b5e4..ccca07dee2 100644 --- a/monai/networks/nets/cell_sam_wrapper.py +++ b/monai/networks/nets/cell_sam_wrapper.py @@ -33,10 +33,10 @@ class CellSamWrapper(torch.nn.Module): auto_resize_inputs: whether to resize inputs before passing to the network. network_resize_roi: expected input size for the network. checkpoint: checkpoint file to load the SAM weights from. - return_features: whether to return features + return_features: whether to return features """ - + def __init__( self, auto_resize_inputs=True, @@ -77,7 +77,7 @@ def forward(self, x): if self.auto_resize_inputs: x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear") - x = self.model.image_encoder(x) + x = self.model.image_encoder(x) if not self.return_features: x = self.model.mask_decoder(x) From 0c605db9a5f7db9dcae0cdf1d7ab1f0ab779ac65 Mon Sep 17 00:00:00 2001 From: myron Date: Mon, 5 Aug 2024 22:44:01 -0700 Subject: [PATCH 03/11] unit test Signed-off-by: myron --- monai/networks/nets/cell_sam_wrapper.py | 13 ++--- tests/test_cell_sam_wrapper.py | 63 +++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 8 deletions(-) create mode 100644 tests/test_cell_sam_wrapper.py diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py index ccca07dee2..b721756461 100644 --- a/monai/networks/nets/cell_sam_wrapper.py +++ b/monai/networks/nets/cell_sam_wrapper.py @@ -15,11 +15,8 @@ from torch import nn from torch.nn import functional as F -try: - from segment_anything.build_sam import build_sam_vit_b -except ImportError: - raise AssertionError("SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git") - +from monai.utils import optional_import +build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") _all__ = ["CellSamWrapper"] @@ -48,13 +45,13 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) - print( - f"CellSamWrapper auto_resize_inputs {auto_resize_inputs} network_resize_roi {network_resize_roi} checkpoint {checkpoint}" - ) self.network_resize_roi = network_resize_roi self.auto_resize_inputs = auto_resize_inputs self.return_features = return_features + if not has_sam: + raise ValueError("SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git") + model = build_sam_vit_b(checkpoint=checkpoint) model.prompt_encoder = None diff --git a/tests/test_cell_sam_wrapper.py b/tests/test_cell_sam_wrapper.py new file mode 100644 index 0000000000..fa88b311d3 --- /dev/null +++ b/tests/test_cell_sam_wrapper.py @@ -0,0 +1,63 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.cell_sam_wrapper import CellSamWrapper +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save + +from monai.utils import optional_import +build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") + +device = "cuda" if torch.cuda.is_available() else "cpu" +TEST_CASE_CELLSEGWRAPPER = [] +for auto_resize_inputs in [True, False]: + for dims in [128, 256, 512, 1024]: + test_case = [ + { + "auto_resize_inputs": True, + "network_resize_roi": [1024, 1024], + "checkpoint": None, + }, + (1, 3, *([dims] * 2)), + (1, 3, *([dims] * 2)), + ] + TEST_CASE_CELLSEGWRAPPER.append(test_case) + + +@unittest.skipUnless(has_sam, "Requires SAM installation") +class TestResNetDS(unittest.TestCase): + + @parameterized.expand(TEST_CASE_CELLSEGWRAPPER) + def test_shape(self, input_param, input_shape, expected_shape): + net = CellSamWrapper(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + def test_ill_arg0(self): + with self.assertRaises(RuntimeError): + net = CellSamWrapper(auto_resize_inputs=False, checkpoint=None).to(device) + net(torch.randn([1, 3, 256, 256])) + + def test_ill_arg1(self): + with self.assertRaises(RuntimeError): + net = CellSamWrapper(network_resize_roi=[256,256], checkpoint=None).to(device) + net(torch.randn([1, 3, 1024, 1024])) + +if __name__ == "__main__": + unittest.main() From 14cae5f5e6514e7f9231b3ecec8400f0eda86ac6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 05:45:06 +0000 Subject: [PATCH 04/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/cell_sam_wrapper.py | 2 +- tests/test_cell_sam_wrapper.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py index b721756461..3bc5db2414 100644 --- a/monai/networks/nets/cell_sam_wrapper.py +++ b/monai/networks/nets/cell_sam_wrapper.py @@ -51,7 +51,7 @@ def __init__( if not has_sam: raise ValueError("SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git") - + model = build_sam_vit_b(checkpoint=checkpoint) model.prompt_encoder = None diff --git a/tests/test_cell_sam_wrapper.py b/tests/test_cell_sam_wrapper.py index fa88b311d3..867f8e7d20 100644 --- a/tests/test_cell_sam_wrapper.py +++ b/tests/test_cell_sam_wrapper.py @@ -18,7 +18,6 @@ from monai.networks import eval_mode from monai.networks.nets.cell_sam_wrapper import CellSamWrapper -from tests.utils import SkipIfBeforePyTorchVersion, test_script_save from monai.utils import optional_import build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") @@ -48,7 +47,7 @@ def test_shape(self, input_param, input_shape, expected_shape): with eval_mode(net): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape, msg=str(input_param)) - + def test_ill_arg0(self): with self.assertRaises(RuntimeError): net = CellSamWrapper(auto_resize_inputs=False, checkpoint=None).to(device) From d7d624b9fdf57a2e231a1632e9a9555c80bb2822 Mon Sep 17 00:00:00 2001 From: myron Date: Fri, 9 Aug 2024 10:49:13 -0700 Subject: [PATCH 05/11] edits Signed-off-by: myron --- monai/networks/nets/cell_sam_wrapper.py | 6 +++++- requirements-dev.txt | 1 + tests/test_cell_sam_wrapper.py | 4 ++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py index 3bc5db2414..250146554f 100644 --- a/monai/networks/nets/cell_sam_wrapper.py +++ b/monai/networks/nets/cell_sam_wrapper.py @@ -28,9 +28,13 @@ class CellSamWrapper(torch.nn.Module): Args: auto_resize_inputs: whether to resize inputs before passing to the network. + (usually they need be resized, unless they are already at the expected size) network_resize_roi: expected input size for the network. + (currently SAM expects 1024x1024) checkpoint: checkpoint file to load the SAM weights from. - return_features: whether to return features + (this can be downloaded from SAM repo https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) + return_features: whether to return features from SAM encoder + (without using decoder/upsampling to the original input size) """ diff --git a/requirements-dev.txt b/requirements-dev.txt index 72ba210093..76f1952345 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -59,3 +59,4 @@ nvidia-ml-py huggingface_hub pyamg>=5.0.0 git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd +git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 diff --git a/tests/test_cell_sam_wrapper.py b/tests/test_cell_sam_wrapper.py index 867f8e7d20..d2be6880bf 100644 --- a/tests/test_cell_sam_wrapper.py +++ b/tests/test_cell_sam_wrapper.py @@ -51,12 +51,12 @@ def test_shape(self, input_param, input_shape, expected_shape): def test_ill_arg0(self): with self.assertRaises(RuntimeError): net = CellSamWrapper(auto_resize_inputs=False, checkpoint=None).to(device) - net(torch.randn([1, 3, 256, 256])) + net(torch.randn([1, 3, 256, 256]).to(device)) def test_ill_arg1(self): with self.assertRaises(RuntimeError): net = CellSamWrapper(network_resize_roi=[256,256], checkpoint=None).to(device) - net(torch.randn([1, 3, 1024, 1024])) + net(torch.randn([1, 3, 1024, 1024]).to(device)) if __name__ == "__main__": unittest.main() From a41a615962bded7969abb6f06cf42aad7d0ffe26 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Aug 2024 17:51:48 +0000 Subject: [PATCH 06/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/cell_sam_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py index 250146554f..b849d5f259 100644 --- a/monai/networks/nets/cell_sam_wrapper.py +++ b/monai/networks/nets/cell_sam_wrapper.py @@ -33,7 +33,7 @@ class CellSamWrapper(torch.nn.Module): (currently SAM expects 1024x1024) checkpoint: checkpoint file to load the SAM weights from. (this can be downloaded from SAM repo https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) - return_features: whether to return features from SAM encoder + return_features: whether to return features from SAM encoder (without using decoder/upsampling to the original input size) """ From 24afcb46887a054feabaa320649258940fbaa5b8 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Sat, 10 Aug 2024 16:57:23 +0800 Subject: [PATCH 07/11] format fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/nets/cell_sam_wrapper.py | 6 +++++- tests/test_cell_sam_wrapper.py | 21 +++++++++------------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py index b849d5f259..c984929147 100644 --- a/monai/networks/nets/cell_sam_wrapper.py +++ b/monai/networks/nets/cell_sam_wrapper.py @@ -16,10 +16,12 @@ from torch.nn import functional as F from monai.utils import optional_import + build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") _all__ = ["CellSamWrapper"] + class CellSamWrapper(torch.nn.Module): """ CellSamWrapper is thin wrapper around SAM model https://github.com/facebookresearch/segment-anything @@ -54,7 +56,9 @@ def __init__( self.return_features = return_features if not has_sam: - raise ValueError("SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git") + raise ValueError( + "SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git" + ) model = build_sam_vit_b(checkpoint=checkpoint) diff --git a/tests/test_cell_sam_wrapper.py b/tests/test_cell_sam_wrapper.py index d2be6880bf..d0f3b1bc9e 100644 --- a/tests/test_cell_sam_wrapper.py +++ b/tests/test_cell_sam_wrapper.py @@ -18,24 +18,20 @@ from monai.networks import eval_mode from monai.networks.nets.cell_sam_wrapper import CellSamWrapper - from monai.utils import optional_import + build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_CELLSEGWRAPPER = [] for auto_resize_inputs in [True, False]: for dims in [128, 256, 512, 1024]: - test_case = [ - { - "auto_resize_inputs": True, - "network_resize_roi": [1024, 1024], - "checkpoint": None, - }, - (1, 3, *([dims] * 2)), - (1, 3, *([dims] * 2)), - ] - TEST_CASE_CELLSEGWRAPPER.append(test_case) + test_case = [ + {"auto_resize_inputs": True, "network_resize_roi": [1024, 1024], "checkpoint": None}, + (1, 3, *([dims] * 2)), + (1, 3, *([dims] * 2)), + ] + TEST_CASE_CELLSEGWRAPPER.append(test_case) @unittest.skipUnless(has_sam, "Requires SAM installation") @@ -55,8 +51,9 @@ def test_ill_arg0(self): def test_ill_arg1(self): with self.assertRaises(RuntimeError): - net = CellSamWrapper(network_resize_roi=[256,256], checkpoint=None).to(device) + net = CellSamWrapper(network_resize_roi=[256, 256], checkpoint=None).to(device) net(torch.randn([1, 3, 1024, 1024]).to(device)) + if __name__ == "__main__": unittest.main() From 8ee636fb900d2df4b398e4f2509a391dabbdce48 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Sat, 10 Aug 2024 17:22:19 +0800 Subject: [PATCH 08/11] fix format issue Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/nets/cell_sam_wrapper.py | 2 +- tests/test_cell_sam_wrapper.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py index c984929147..308c3a6bcb 100644 --- a/monai/networks/nets/cell_sam_wrapper.py +++ b/monai/networks/nets/cell_sam_wrapper.py @@ -43,7 +43,7 @@ class CellSamWrapper(torch.nn.Module): def __init__( self, auto_resize_inputs=True, - network_resize_roi=[1024, 1024], + network_resize_roi=(1024, 1024), checkpoint="sam_vit_b_01ec64.pth", return_features=False, *args, diff --git a/tests/test_cell_sam_wrapper.py b/tests/test_cell_sam_wrapper.py index d0f3b1bc9e..2f1ee2b901 100644 --- a/tests/test_cell_sam_wrapper.py +++ b/tests/test_cell_sam_wrapper.py @@ -24,14 +24,13 @@ device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_CELLSEGWRAPPER = [] -for auto_resize_inputs in [True, False]: - for dims in [128, 256, 512, 1024]: - test_case = [ - {"auto_resize_inputs": True, "network_resize_roi": [1024, 1024], "checkpoint": None}, - (1, 3, *([dims] * 2)), - (1, 3, *([dims] * 2)), - ] - TEST_CASE_CELLSEGWRAPPER.append(test_case) +for dims in [128, 256, 512, 1024]: + test_case = [ + {"auto_resize_inputs": True, "network_resize_roi": [1024, 1024], "checkpoint": None}, + (1, 3, *([dims] * 2)), + (1, 3, *([dims] * 2)), + ] + TEST_CASE_CELLSEGWRAPPER.append(test_case) @unittest.skipUnless(has_sam, "Requires SAM installation") From b5c79f1b5924da134008d17e9950f50342c20beb Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Sat, 10 Aug 2024 17:42:37 +0800 Subject: [PATCH 09/11] try add it in setup.cfg Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- setup.cfg | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.cfg b/setup.cfg index dfa94fcfa1..657039c71a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -85,6 +85,7 @@ all = nvidia-ml-py huggingface_hub pyamg>=5.0.0 + segment-anything nibabel = nibabel ninja = @@ -167,6 +168,8 @@ huggingface_hub = huggingface_hub pyamg = pyamg>=5.0.0 +segment-anything = + git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 [flake8] select = B,C,E,F,N,P,T4,W,B9 From 6d4a047e20da8c93a98529bf541dc3a05c75b122 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Sat, 10 Aug 2024 17:47:00 +0800 Subject: [PATCH 10/11] add it in setup.cfg Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 657039c71a..e240445e36 100644 --- a/setup.cfg +++ b/setup.cfg @@ -163,13 +163,13 @@ pynvml = nvidia-ml-py # # workaround https://github.com/Project-MONAI/MONAI/issues/5882 # MetricsReloaded = -# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded + # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded huggingface_hub = huggingface_hub pyamg = pyamg>=5.0.0 segment-anything = - git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 + segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything [flake8] select = B,C,E,F,N,P,T4,W,B9 From 53ac8e1e38e2bf94cf01c776ae5b762cc68b24ba Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Sat, 10 Aug 2024 17:51:43 +0800 Subject: [PATCH 11/11] update installation md Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/installation.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/installation.md b/docs/source/installation.md index 4308a07647..70a8b6f1d4 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub] +[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub, segment-anything] ``` which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub`, `pyamg` and `segment-anything` respectively. - `pip install 'monai[all]'` installs all the optional dependencies.