From 7f255ad8f6aa8579c99dac626110c2b2b12ccc55 Mon Sep 17 00:00:00 2001 From: tangy5 Date: Mon, 3 Apr 2023 17:58:14 -0700 Subject: [PATCH 01/23] Add text to vision embedding Signed-off-by: tangy5 --- monai/networks/blocks/text_enbedding.py | 60 +++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 monai/networks/blocks/text_enbedding.py diff --git a/monai/networks/blocks/text_enbedding.py b/monai/networks/blocks/text_enbedding.py new file mode 100644 index 0000000000..290a80ff8e --- /dev/null +++ b/monai/networks/blocks/text_enbedding.py @@ -0,0 +1,60 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn + +class TextEncoder(nn.Module): + """ + Text to vision encoding by Contrastive Language-Image Pre-training (CLIP) or random embedding. + The text to vision encoder loads the pre-trained or random initialized weights with connection to 2D/3D vision models. + + Contrastive Language-Image Pre-training (CLIP), based on: "Radford et al., + Learning Transferable Visual Models From Natural Language Supervision " + + Connecting text and medical 3D image, based on: "Liu et al., + CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection " + """ + def __init__( + self, + out_channels: int, + text_dim: int = 512, + hidden_size: int = 256, + encoding: str = "clip_embedding", + ) -> None: + """ + Args: + out_channels: number of output channels, to control text-baesd embedding for classes. + text_dim: dimension of text embeddings. + hidden_size: dimension of hidden features, compatible to different vision feature dimensions. + encoding: the text embedding type, default to use clip text pretrained weights + """ + self.encoding = encoding + + if self.encoding == 'rand_embedding': + self.organ_embedding = nn.Embedding(out_channels, hidden_size) + elif self.encoding == 'clip_embedding': + self.register_buffer('text_embedding', torch.randn(out_channels, text_dim)) + self.text_to_vision = nn.Linear(text_dim, hidden_size) + + def forward(self): + if self.encoding == 'clip_embedding': + task_encoding = nn.function.relu(self.text_to_vision(self.organ_embedding)) + task_encoding = task_encoding.unsqueeze(2).unsqueeze(2).unsqueeze(2) + else: + # text embedding as random initialized 'rand_embedding' + task_encoding = self.text_embedding.weight.unsqueeze(2).unsqueeze(2).unsqueeze(2) + return task_encoding + + + From 2dd5566422815539ab4d599195f620ff2beabaef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Apr 2023 01:01:39 +0000 Subject: [PATCH 02/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/blocks/text_enbedding.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/monai/networks/blocks/text_enbedding.py b/monai/networks/blocks/text_enbedding.py index 290a80ff8e..9028a3cc30 100644 --- a/monai/networks/blocks/text_enbedding.py +++ b/monai/networks/blocks/text_enbedding.py @@ -24,7 +24,7 @@ class TextEncoder(nn.Module): Connecting text and medical 3D image, based on: "Liu et al., CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection " - """ + """ def __init__( self, out_channels: int, @@ -55,6 +55,3 @@ def forward(self): # text embedding as random initialized 'rand_embedding' task_encoding = self.text_embedding.weight.unsqueeze(2).unsqueeze(2).unsqueeze(2) return task_encoding - - - From 81c29656e54322d1fad9b9629384bf71040b6a25 Mon Sep 17 00:00:00 2001 From: tangy5 Date: Tue, 4 Apr 2023 16:52:45 -0700 Subject: [PATCH 03/23] update parameters Signed-off-by: tangy5 --- monai/networks/blocks/text_enbedding.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) mode change 100644 => 100755 monai/networks/blocks/text_enbedding.py diff --git a/monai/networks/blocks/text_enbedding.py b/monai/networks/blocks/text_enbedding.py old mode 100644 new mode 100755 index 9028a3cc30..a1f69dd457 --- a/monai/networks/blocks/text_enbedding.py +++ b/monai/networks/blocks/text_enbedding.py @@ -24,7 +24,7 @@ class TextEncoder(nn.Module): Connecting text and medical 3D image, based on: "Liu et al., CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection " - """ + """ def __init__( self, out_channels: int, @@ -39,6 +39,7 @@ def __init__( hidden_size: dimension of hidden features, compatible to different vision feature dimensions. encoding: the text embedding type, default to use clip text pretrained weights """ + super().__init__() self.encoding = encoding if self.encoding == 'rand_embedding': @@ -49,9 +50,12 @@ def __init__( def forward(self): if self.encoding == 'clip_embedding': - task_encoding = nn.function.relu(self.text_to_vision(self.organ_embedding)) + task_encoding = nn.functional.relu(self.text_to_vision(self.text_embedding)) task_encoding = task_encoding.unsqueeze(2).unsqueeze(2).unsqueeze(2) else: # text embedding as random initialized 'rand_embedding' task_encoding = self.text_embedding.weight.unsqueeze(2).unsqueeze(2).unsqueeze(2) return task_encoding + + + From f84ac5d06a391122131728ec8a35266926e16783 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Apr 2023 23:56:25 +0000 Subject: [PATCH 04/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/blocks/text_enbedding.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/monai/networks/blocks/text_enbedding.py b/monai/networks/blocks/text_enbedding.py index a1f69dd457..1f1a6e7144 100755 --- a/monai/networks/blocks/text_enbedding.py +++ b/monai/networks/blocks/text_enbedding.py @@ -24,7 +24,7 @@ class TextEncoder(nn.Module): Connecting text and medical 3D image, based on: "Liu et al., CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection " - """ + """ def __init__( self, out_channels: int, @@ -56,6 +56,3 @@ def forward(self): # text embedding as random initialized 'rand_embedding' task_encoding = self.text_embedding.weight.unsqueeze(2).unsqueeze(2).unsqueeze(2) return task_encoding - - - From d737121b6363b7e5a5e4a774313b1626bd301031 Mon Sep 17 00:00:00 2001 From: tangy5 Date: Tue, 4 Apr 2023 21:58:11 -0700 Subject: [PATCH 05/23] update encoding Signed-off-by: tangy5 --- .../networks/blocks/{text_enbedding.py => text_embedding.py} | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) rename monai/networks/blocks/{text_enbedding.py => text_embedding.py} (99%) diff --git a/monai/networks/blocks/text_enbedding.py b/monai/networks/blocks/text_embedding.py similarity index 99% rename from monai/networks/blocks/text_enbedding.py rename to monai/networks/blocks/text_embedding.py index 1f1a6e7144..a1f69dd457 100755 --- a/monai/networks/blocks/text_enbedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -24,7 +24,7 @@ class TextEncoder(nn.Module): Connecting text and medical 3D image, based on: "Liu et al., CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection " - """ + """ def __init__( self, out_channels: int, @@ -56,3 +56,6 @@ def forward(self): # text embedding as random initialized 'rand_embedding' task_encoding = self.text_embedding.weight.unsqueeze(2).unsqueeze(2).unsqueeze(2) return task_encoding + + + From ae7c2fe1b10dafa81252483231aa1f8f691e9e7a Mon Sep 17 00:00:00 2001 From: tangy5 Date: Tue, 4 Apr 2023 22:04:24 -0700 Subject: [PATCH 06/23] change file mode Signed-off-by: tangy5 --- monai/networks/blocks/text_embedding.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) mode change 100755 => 100644 monai/networks/blocks/text_embedding.py diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py old mode 100755 new mode 100644 index a1f69dd457..576373d0ad --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -50,12 +50,12 @@ def __init__( def forward(self): if self.encoding == 'clip_embedding': - task_encoding = nn.functional.relu(self.text_to_vision(self.text_embedding)) - task_encoding = task_encoding.unsqueeze(2).unsqueeze(2).unsqueeze(2) + test_encoding = nn.functional.relu(self.text_to_vision(self.text_embedding)) + test_encoding = test_encoding.unsqueeze(2).unsqueeze(2).unsqueeze(2) else: # text embedding as random initialized 'rand_embedding' - task_encoding = self.text_embedding.weight.unsqueeze(2).unsqueeze(2).unsqueeze(2) - return task_encoding + test_encoding = self.text_embedding.weight.unsqueeze(2).unsqueeze(2).unsqueeze(2) + return test_encoding From bd4dc378ec030f0d25726e440d8e425efeb2ad11 Mon Sep 17 00:00:00 2001 From: tangy5 Date: Tue, 4 Apr 2023 22:16:20 -0700 Subject: [PATCH 07/23] fix flake8 format Signed-off-by: tangy5 --- monai/networks/blocks/text_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index 576373d0ad..84f670ec73 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -12,7 +12,7 @@ from __future__ import annotations import torch -import torch.nn as nn +from torch import nn class TextEncoder(nn.Module): """ From 3fd70e32dde0bc3120b86c69e308de2657af55af Mon Sep 17 00:00:00 2001 From: tangy5 Date: Tue, 4 Apr 2023 22:24:55 -0700 Subject: [PATCH 08/23] fix flake8 format2 Signed-off-by: tangy5 --- monai/networks/blocks/text_embedding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index 84f670ec73..1f4588ea20 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -14,6 +14,7 @@ import torch from torch import nn + class TextEncoder(nn.Module): """ Text to vision encoding by Contrastive Language-Image Pre-training (CLIP) or random embedding. From a65d6c18d9aaa8bf226428b206eb55bdef632a5b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Apr 2023 05:25:34 +0000 Subject: [PATCH 09/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/blocks/text_embedding.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index 1f4588ea20..ae4ba71c7e 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -25,7 +25,7 @@ class TextEncoder(nn.Module): Connecting text and medical 3D image, based on: "Liu et al., CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection " - """ + """ def __init__( self, out_channels: int, @@ -57,6 +57,3 @@ def forward(self): # text embedding as random initialized 'rand_embedding' test_encoding = self.text_embedding.weight.unsqueeze(2).unsqueeze(2).unsqueeze(2) return test_encoding - - - From 79a8f853f79db1c77312afb414ba1435847dfa09 Mon Sep 17 00:00:00 2001 From: tangy5 Date: Thu, 6 Apr 2023 15:55:10 -0700 Subject: [PATCH 10/23] update var name Signed-off-by: tangy5 --- monai/networks/blocks/text_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index ae4ba71c7e..e821acb558 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -44,7 +44,7 @@ def __init__( self.encoding = encoding if self.encoding == 'rand_embedding': - self.organ_embedding = nn.Embedding(out_channels, hidden_size) + self.text_embedding = nn.Embedding(out_channels, hidden_size) elif self.encoding == 'clip_embedding': self.register_buffer('text_embedding', torch.randn(out_channels, text_dim)) self.text_to_vision = nn.Linear(text_dim, hidden_size) From 2cc0ce4c0a532465ecef90729f3382d8adf3a2f3 Mon Sep 17 00:00:00 2001 From: tangy5 Date: Thu, 6 Apr 2023 16:03:50 -0700 Subject: [PATCH 11/23] update var name Signed-off-by: tangy5 --- monai/networks/blocks/text_embedding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index e821acb558..a1713b4f07 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -48,6 +48,8 @@ def __init__( elif self.encoding == 'clip_embedding': self.register_buffer('text_embedding', torch.randn(out_channels, text_dim)) self.text_to_vision = nn.Linear(text_dim, hidden_size) + else: + raise Exception('{} is not implemented, please add your own'.format(self.encoding)) def forward(self): if self.encoding == 'clip_embedding': From 88f392fa806f49296f30ee91c41ef8cd11c6ca2c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 6 Apr 2023 23:04:24 +0000 Subject: [PATCH 12/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/blocks/text_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index a1713b4f07..4b8eccd6e3 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -49,7 +49,7 @@ def __init__( self.register_buffer('text_embedding', torch.randn(out_channels, text_dim)) self.text_to_vision = nn.Linear(text_dim, hidden_size) else: - raise Exception('{} is not implemented, please add your own'.format(self.encoding)) + raise Exception(f'{self.encoding} is not implemented, please add your own') def forward(self): if self.encoding == 'clip_embedding': From fefa8d13940a67a0cd3e04bf1f8f0b1555a47569 Mon Sep 17 00:00:00 2001 From: tangy5 Date: Wed, 12 Apr 2023 12:10:28 -0700 Subject: [PATCH 13/23] update 2d case, pretrain option, release CLIP weights Signed-off-by: tangy5 --- monai/networks/blocks/text_embedding.py | 29 +++++++++++++++-- tests/test_text_encoding.py | 43 +++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 3 deletions(-) create mode 100644 tests/test_text_encoding.py diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index 4b8eccd6e3..50a8bf73f0 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -13,6 +13,11 @@ import torch from torch import nn +from torch.utils import model_zoo + +url_map = { + "clip_encoding_univeral_model_31": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/clip_encoding_univeral_model.pth", +} class TextEncoder(nn.Module): @@ -29,24 +34,37 @@ class TextEncoder(nn.Module): def __init__( self, out_channels: int, + spatial_dims: int = 3, text_dim: int = 512, hidden_size: int = 256, encoding: str = "clip_embedding", + pretrained: bool = False ) -> None: """ Args: out_channels: number of output channels, to control text-baesd embedding for classes. + spatial_dims: number of spatial dims. text_dim: dimension of text embeddings. hidden_size: dimension of hidden features, compatible to different vision feature dimensions. - encoding: the text embedding type, default to use clip text pretrained weights + encoding: the text embedding type, default to use clip text pretrained weights. + pretrained: whether to load pretrained weights from e.g., (CLIP) to initialize text embeddings, default to False. """ super().__init__() self.encoding = encoding + self.spatial_dims = spatial_dims + if spatial_dims not in (2, 3): + raise ValueError("spatial dimension should be 2 or 3.") + if self.encoding == 'rand_embedding': self.text_embedding = nn.Embedding(out_channels, hidden_size) elif self.encoding == 'clip_embedding': self.register_buffer('text_embedding', torch.randn(out_channels, text_dim)) + if pretrained: + model_url = url_map["clip_encoding_univeral_model_31"] + pretrain_state_dict = model_zoo.load_url(model_url) + self.text_embedding.data = pretrain_state_dict.float() + print('load word embedding: {}'.format(self.encoding)) self.text_to_vision = nn.Linear(text_dim, hidden_size) else: raise Exception(f'{self.encoding} is not implemented, please add your own') @@ -54,8 +72,13 @@ def __init__( def forward(self): if self.encoding == 'clip_embedding': test_encoding = nn.functional.relu(self.text_to_vision(self.text_embedding)) - test_encoding = test_encoding.unsqueeze(2).unsqueeze(2).unsqueeze(2) else: # text embedding as random initialized 'rand_embedding' - test_encoding = self.text_embedding.weight.unsqueeze(2).unsqueeze(2).unsqueeze(2) + test_encoding = self.text_embedding.weight + + if self.spatial_dims == 3: + test_encoding = test_encoding.unsqueeze(2).unsqueeze(2).unsqueeze(2) + elif self.spatial_dims == 2: + test_encoding = test_encoding.unsqueeze(2).unsqueeze(2) + return test_encoding diff --git a/tests/test_text_encoding.py b/tests/test_text_encoding.py new file mode 100644 index 0000000000..a4ee3b1550 --- /dev/null +++ b/tests/test_text_encoding.py @@ -0,0 +1,43 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from monai.networks.blocks.text_embedding import TextEncoder + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +class TestTextEncoder(unittest.TestCase): + def test_test_encoding_shape(self): + # test 2D encoder + text_encoder = TextEncoder(spatial_dims=2, out_channels=32, pretrained=True).to(device) + text_encoding = text_encoder() + print(text_encoding.shape) + self.assertEqual(text_encoding.shape, (32,256,1,1)) + + # test 3D encoder + text_encoder = TextEncoder(spatial_dims=3, out_channels=32, pretrained=True).to(device) + text_encoding = text_encoder() + print(text_encoding.shape) + self.assertEqual(text_encoding.shape, (32,256,1,1,1)) + + # test random enbedding + text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True).to(device) + text_encoding = text_encoder() + print(text_encoding.shape) + self.assertEqual(text_encoding.shape, (32,256,1,1,1)) + +if __name__ == "__main__": + unittest.main() From 4ba43a9681025fb7f5fd2b75b488ca96b9fc8b96 Mon Sep 17 00:00:00 2001 From: tangy5 Date: Wed, 12 Apr 2023 12:20:36 -0700 Subject: [PATCH 14/23] loadable options Signed-off-by: tangy5 --- monai/networks/blocks/text_embedding.py | 29 +++++++++++++------------ 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index 50a8bf73f0..d55b075c0f 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -16,7 +16,7 @@ from torch.utils import model_zoo url_map = { - "clip_encoding_univeral_model_31": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/clip_encoding_univeral_model.pth", + "clip_encoding_univeral_model_32": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/clip_encoding_univeral_model.pth", } @@ -37,7 +37,7 @@ def __init__( spatial_dims: int = 3, text_dim: int = 512, hidden_size: int = 256, - encoding: str = "clip_embedding", + encoding: str = "rand_embedding", pretrained: bool = False ) -> None: """ @@ -58,23 +58,24 @@ def __init__( if self.encoding == 'rand_embedding': self.text_embedding = nn.Embedding(out_channels, hidden_size) - elif self.encoding == 'clip_embedding': - self.register_buffer('text_embedding', torch.randn(out_channels, text_dim)) - if pretrained: - model_url = url_map["clip_encoding_univeral_model_31"] - pretrain_state_dict = model_zoo.load_url(model_url) - self.text_embedding.data = pretrain_state_dict.float() - print('load word embedding: {}'.format(self.encoding)) - self.text_to_vision = nn.Linear(text_dim, hidden_size) else: - raise Exception(f'{self.encoding} is not implemented, please add your own') + if self.encoding in url_map: + self.register_buffer('text_embedding', torch.randn(out_channels, text_dim)) + if pretrained: + model_url = url_map[self.encoding] + pretrain_state_dict = model_zoo.load_url(model_url) + self.text_embedding.data = pretrain_state_dict.float() + print('load text embedding: {}'.format(self.encoding)) + self.text_to_vision = nn.Linear(text_dim, hidden_size) + else: + raise Exception(f'{self.encoding} is not implemented, please add your own') def forward(self): - if self.encoding == 'clip_embedding': - test_encoding = nn.functional.relu(self.text_to_vision(self.text_embedding)) - else: + if self.encoding == 'rand_embedding': # text embedding as random initialized 'rand_embedding' test_encoding = self.text_embedding.weight + else: + test_encoding = nn.functional.relu(self.text_to_vision(self.text_embedding)) if self.spatial_dims == 3: test_encoding = test_encoding.unsqueeze(2).unsqueeze(2).unsqueeze(2) From 6b3d8fea78f9fa776b111c9a694efea37885b8c6 Mon Sep 17 00:00:00 2001 From: tangy5 Date: Wed, 12 Apr 2023 12:23:27 -0700 Subject: [PATCH 15/23] remove print Signed-off-by: tangy5 --- monai/networks/blocks/text_embedding.py | 1 - tests/test_text_encoding.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index d55b075c0f..c65ecc1cdf 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -65,7 +65,6 @@ def __init__( model_url = url_map[self.encoding] pretrain_state_dict = model_zoo.load_url(model_url) self.text_embedding.data = pretrain_state_dict.float() - print('load text embedding: {}'.format(self.encoding)) self.text_to_vision = nn.Linear(text_dim, hidden_size) else: raise Exception(f'{self.encoding} is not implemented, please add your own') diff --git a/tests/test_text_encoding.py b/tests/test_text_encoding.py index a4ee3b1550..787f118e4f 100644 --- a/tests/test_text_encoding.py +++ b/tests/test_text_encoding.py @@ -22,13 +22,13 @@ class TestTextEncoder(unittest.TestCase): def test_test_encoding_shape(self): # test 2D encoder - text_encoder = TextEncoder(spatial_dims=2, out_channels=32, pretrained=True).to(device) + text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device) text_encoding = text_encoder() print(text_encoding.shape) self.assertEqual(text_encoding.shape, (32,256,1,1)) # test 3D encoder - text_encoder = TextEncoder(spatial_dims=3, out_channels=32, pretrained=True).to(device) + text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device) text_encoding = text_encoder() print(text_encoding.shape) self.assertEqual(text_encoding.shape, (32,256,1,1,1)) From af30d79e4e660c179330202ff865e9cf90f127ec Mon Sep 17 00:00:00 2001 From: tangy5 Date: Wed, 12 Apr 2023 12:28:52 -0700 Subject: [PATCH 16/23] add skip if downloading fails Signed-off-by: tangy5 --- tests/test_text_encoding.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/test_text_encoding.py b/tests/test_text_encoding.py index 787f118e4f..952a583ae4 100644 --- a/tests/test_text_encoding.py +++ b/tests/test_text_encoding.py @@ -15,29 +15,33 @@ import torch from monai.networks.blocks.text_embedding import TextEncoder +from tests.utils import skip_if_downloading_fails device = "cuda" if torch.cuda.is_available() else "cpu" class TestTextEncoder(unittest.TestCase): def test_test_encoding_shape(self): - # test 2D encoder - text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device) - text_encoding = text_encoder() - print(text_encoding.shape) - self.assertEqual(text_encoding.shape, (32,256,1,1)) - - # test 3D encoder - text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device) + with skip_if_downloading_fails(): + # test 2D encoder + text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device) + text_encoding = text_encoder() + self.assertEqual(text_encoding.shape, (32,256,1,1)) + + # test 3D encoder + text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device) + text_encoding = text_encoder() + self.assertEqual(text_encoding.shape, (32,256,1,1,1)) + + # test random enbedding 3D + text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True).to(device) text_encoding = text_encoder() - print(text_encoding.shape) self.assertEqual(text_encoding.shape, (32,256,1,1,1)) - # test random enbedding - text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True).to(device) + # test random enbedding 2D + text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="rand_embedding", pretrained=True).to(device) text_encoding = text_encoder() - print(text_encoding.shape) - self.assertEqual(text_encoding.shape, (32,256,1,1,1)) + self.assertEqual(text_encoding.shape, (32,256,1,1)) if __name__ == "__main__": unittest.main() From 900729d73f3ba8962053a58a46a4298dce9d8e31 Mon Sep 17 00:00:00 2001 From: tangy5 Date: Wed, 12 Apr 2023 12:50:07 -0700 Subject: [PATCH 17/23] update pretrained load logic Signed-off-by: tangy5 --- monai/networks/blocks/text_embedding.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index c65ecc1cdf..f4dbf72bb3 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -59,15 +59,16 @@ def __init__( if self.encoding == 'rand_embedding': self.text_embedding = nn.Embedding(out_channels, hidden_size) else: - if self.encoding in url_map: - self.register_buffer('text_embedding', torch.randn(out_channels, text_dim)) - if pretrained: - model_url = url_map[self.encoding] - pretrain_state_dict = model_zoo.load_url(model_url) - self.text_embedding.data = pretrain_state_dict.float() - self.text_to_vision = nn.Linear(text_dim, hidden_size) + self.register_buffer('text_embedding', torch.randn(out_channels, text_dim)) + + if pretrained: + model_url = url_map[self.encoding] + pretrain_state_dict = model_zoo.load_url(model_url) + self.text_embedding.data = pretrain_state_dict.float() else: - raise Exception(f'{self.encoding} is not implemented, please add your own') + print(f'{self.encoding} is not implemented, and can not be downloaded, please load your own') + + self.text_to_vision = nn.Linear(text_dim, hidden_size) def forward(self): if self.encoding == 'rand_embedding': From 2be586744213c3ab2b36f704695b912b6ad35697 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Apr 2023 19:52:02 +0000 Subject: [PATCH 18/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/blocks/text_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index f4dbf72bb3..4e266987c6 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -55,7 +55,7 @@ def __init__( self.spatial_dims = spatial_dims if spatial_dims not in (2, 3): raise ValueError("spatial dimension should be 2 or 3.") - + if self.encoding == 'rand_embedding': self.text_embedding = nn.Embedding(out_channels, hidden_size) else: From c2a175564dc7f1f04894f46c5a9804752cc603e1 Mon Sep 17 00:00:00 2001 From: tangy5 Date: Thu, 13 Apr 2023 11:48:47 -0700 Subject: [PATCH 19/23] fix cpu only test and others Signed-off-by: tangy5 --- monai/networks/blocks/text_embedding.py | 17 +++++++++-------- tests/test_text_encoding.py | 11 ++++------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index 4e266987c6..a5637632d1 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -37,8 +37,8 @@ def __init__( spatial_dims: int = 3, text_dim: int = 512, hidden_size: int = 256, - encoding: str = "rand_embedding", - pretrained: bool = False + encoding: str = "clip_encoding_univeral_model_32", + pretrained: bool = True ) -> None: """ Args: @@ -63,7 +63,7 @@ def __init__( if pretrained: model_url = url_map[self.encoding] - pretrain_state_dict = model_zoo.load_url(model_url) + pretrain_state_dict = model_zoo.load_url(model_url,map_location="cpu") self.text_embedding.data = pretrain_state_dict.float() else: print(f'{self.encoding} is not implemented, and can not be downloaded, please load your own') @@ -73,13 +73,14 @@ def __init__( def forward(self): if self.encoding == 'rand_embedding': # text embedding as random initialized 'rand_embedding' - test_encoding = self.text_embedding.weight + text_embedding = self.text_embedding.weight else: - test_encoding = nn.functional.relu(self.text_to_vision(self.text_embedding)) + print(self.text_embedding) + text_embedding = nn.functional.relu(self.text_to_vision(self.text_embedding)) if self.spatial_dims == 3: - test_encoding = test_encoding.unsqueeze(2).unsqueeze(2).unsqueeze(2) + text_embedding = text_embedding.unsqueeze(2).unsqueeze(2).unsqueeze(2) elif self.spatial_dims == 2: - test_encoding = test_encoding.unsqueeze(2).unsqueeze(2) + text_embedding = text_embedding.unsqueeze(2).unsqueeze(2) - return test_encoding + return text_embedding diff --git a/tests/test_text_encoding.py b/tests/test_text_encoding.py index 952a583ae4..fcd32693e1 100644 --- a/tests/test_text_encoding.py +++ b/tests/test_text_encoding.py @@ -17,29 +17,26 @@ from monai.networks.blocks.text_embedding import TextEncoder from tests.utils import skip_if_downloading_fails -device = "cuda" if torch.cuda.is_available() else "cpu" - - class TestTextEncoder(unittest.TestCase): def test_test_encoding_shape(self): with skip_if_downloading_fails(): # test 2D encoder - text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device) + text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True) text_encoding = text_encoder() self.assertEqual(text_encoding.shape, (32,256,1,1)) # test 3D encoder - text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True).to(device) + text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True) text_encoding = text_encoder() self.assertEqual(text_encoding.shape, (32,256,1,1,1)) # test random enbedding 3D - text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True).to(device) + text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True) text_encoding = text_encoder() self.assertEqual(text_encoding.shape, (32,256,1,1,1)) # test random enbedding 2D - text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="rand_embedding", pretrained=True).to(device) + text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="rand_embedding", pretrained=True) text_encoding = text_encoder() self.assertEqual(text_encoding.shape, (32,256,1,1)) From 4392a464e527de4169bfeda5240a65a5042c5fa7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Apr 2023 18:49:35 +0000 Subject: [PATCH 20/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_text_encoding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_text_encoding.py b/tests/test_text_encoding.py index fcd32693e1..3067b8d5bf 100644 --- a/tests/test_text_encoding.py +++ b/tests/test_text_encoding.py @@ -13,7 +13,6 @@ import unittest -import torch from monai.networks.blocks.text_embedding import TextEncoder from tests.utils import skip_if_downloading_fails From 447509b42a3d84c61caf1a978a22bbf699b389c7 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Thu, 13 Apr 2023 19:12:26 +0000 Subject: [PATCH 21/23] [MONAI] code formatting Signed-off-by: monai-bot --- monai/networks/blocks/text_embedding.py | 15 ++++++++------- tests/test_text_encoding.py | 18 ++++++++++++------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index a5637632d1..f56a5bd007 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -16,7 +16,7 @@ from torch.utils import model_zoo url_map = { - "clip_encoding_univeral_model_32": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/clip_encoding_univeral_model.pth", + "clip_encoding_univeral_model_32": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/clip_encoding_univeral_model.pth" } @@ -31,6 +31,7 @@ class TextEncoder(nn.Module): Connecting text and medical 3D image, based on: "Liu et al., CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection " """ + def __init__( self, out_channels: int, @@ -38,7 +39,7 @@ def __init__( text_dim: int = 512, hidden_size: int = 256, encoding: str = "clip_encoding_univeral_model_32", - pretrained: bool = True + pretrained: bool = True, ) -> None: """ Args: @@ -56,22 +57,22 @@ def __init__( if spatial_dims not in (2, 3): raise ValueError("spatial dimension should be 2 or 3.") - if self.encoding == 'rand_embedding': + if self.encoding == "rand_embedding": self.text_embedding = nn.Embedding(out_channels, hidden_size) else: - self.register_buffer('text_embedding', torch.randn(out_channels, text_dim)) + self.register_buffer("text_embedding", torch.randn(out_channels, text_dim)) if pretrained: model_url = url_map[self.encoding] - pretrain_state_dict = model_zoo.load_url(model_url,map_location="cpu") + pretrain_state_dict = model_zoo.load_url(model_url, map_location="cpu") self.text_embedding.data = pretrain_state_dict.float() else: - print(f'{self.encoding} is not implemented, and can not be downloaded, please load your own') + print(f"{self.encoding} is not implemented, and can not be downloaded, please load your own") self.text_to_vision = nn.Linear(text_dim, hidden_size) def forward(self): - if self.encoding == 'rand_embedding': + if self.encoding == "rand_embedding": # text embedding as random initialized 'rand_embedding' text_embedding = self.text_embedding.weight else: diff --git a/tests/test_text_encoding.py b/tests/test_text_encoding.py index 3067b8d5bf..06c95c4111 100644 --- a/tests/test_text_encoding.py +++ b/tests/test_text_encoding.py @@ -16,28 +16,34 @@ from monai.networks.blocks.text_embedding import TextEncoder from tests.utils import skip_if_downloading_fails + class TestTextEncoder(unittest.TestCase): def test_test_encoding_shape(self): with skip_if_downloading_fails(): # test 2D encoder - text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True) + text_encoder = TextEncoder( + spatial_dims=2, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True + ) text_encoding = text_encoder() - self.assertEqual(text_encoding.shape, (32,256,1,1)) + self.assertEqual(text_encoding.shape, (32, 256, 1, 1)) # test 3D encoder - text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True) + text_encoder = TextEncoder( + spatial_dims=3, out_channels=32, encoding="clip_encoding_univeral_model_32", pretrained=True + ) text_encoding = text_encoder() - self.assertEqual(text_encoding.shape, (32,256,1,1,1)) + self.assertEqual(text_encoding.shape, (32, 256, 1, 1, 1)) # test random enbedding 3D text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding="rand_embedding", pretrained=True) text_encoding = text_encoder() - self.assertEqual(text_encoding.shape, (32,256,1,1,1)) + self.assertEqual(text_encoding.shape, (32, 256, 1, 1, 1)) # test random enbedding 2D text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding="rand_embedding", pretrained=True) text_encoding = text_encoder() - self.assertEqual(text_encoding.shape, (32,256,1,1)) + self.assertEqual(text_encoding.shape, (32, 256, 1, 1)) + if __name__ == "__main__": unittest.main() From bd6c4e32c8b102747ae6ddf3a631f9eb8b2fe755 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 13 Apr 2023 20:17:44 +0100 Subject: [PATCH 22/23] fixes Signed-off-by: Wenqi Li --- monai/networks/blocks/text_embedding.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index f56a5bd007..56fedc969c 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -16,7 +16,10 @@ from torch.utils import model_zoo url_map = { - "clip_encoding_univeral_model_32": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/clip_encoding_univeral_model.pth" + "clip_encoding_univeral_model_32": ( + "https://github.com/Project-MONAI/MONAI-extra-test-data/" + "releases/download/0.8.1/clip_encoding_univeral_model.pth" + ) } From 72bf74a35b4dcaab63ce0f4d4085d0526c2f327f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 13 Apr 2023 20:19:16 +0100 Subject: [PATCH 23/23] fixes Signed-off-by: Wenqi Li --- monai/networks/blocks/text_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/text_embedding.py b/monai/networks/blocks/text_embedding.py index 56fedc969c..187922b835 100644 --- a/monai/networks/blocks/text_embedding.py +++ b/monai/networks/blocks/text_embedding.py @@ -68,7 +68,7 @@ def __init__( if pretrained: model_url = url_map[self.encoding] pretrain_state_dict = model_zoo.load_url(model_url, map_location="cpu") - self.text_embedding.data = pretrain_state_dict.float() + self.text_embedding.data = pretrain_state_dict.float() # type: ignore else: print(f"{self.encoding} is not implemented, and can not be downloaded, please load your own")