diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index 1020ff1026..36d62752d4 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -500,9 +500,9 @@ Nets
 .. autoclass:: Critic
     :members:
 
-`VLTransformers`
+`Transchex`
 ~~~~~~~~~~~~~~~~
-.. autoclass:: VLTransformers
+.. autoclass:: Transchex
     :members:
 
 `NetAdapter`
diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py
index 0ddda1d6dd..6076fcbe3d 100644
--- a/monai/networks/nets/__init__.py
+++ b/monai/networks/nets/__init__.py
@@ -80,17 +80,9 @@
     seresnext101,
 )
 from .torchvision_fc import TorchVisionFCModel, TorchVisionFullyConvModel
+from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
 from .unet import UNet, Unet, unet
 from .unetr import UNETR
 from .varautoencoder import VarAutoEncoder
 from .vit import ViT
-from .vltransformer import (
-    BertAttention,
-    BertMixedLayer,
-    BertOutput,
-    BertPreTrainedModel,
-    MultiModal,
-    Pooler,
-    VLTransformers,
-)
 from .vnet import VNet
diff --git a/monai/networks/nets/vltransformer.py b/monai/networks/nets/transchex.py
similarity index 98%
rename from monai/networks/nets/vltransformer.py
rename to monai/networks/nets/transchex.py
index 23a1a39ded..1ec5039e6a 100644
--- a/monai/networks/nets/vltransformer.py
+++ b/monai/networks/nets/transchex.py
@@ -34,7 +34,7 @@
     "BertMixedLayer",
     "Pooler",
     "MultiModal",
-    "VLTransformers",
+    "Transchex",
 ]
 
 
@@ -266,9 +266,10 @@ def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_m
         return hidden_state_mixed
 
 
-class VLTransformers(torch.nn.Module):
+class Transchex(torch.nn.Module):
     """
-    Vision Language Multimodal Transformers
+    TransChex based on: "Hatamizadeh et al., TransCheX: Self-Supervised Pretraining of Vision-Language
+    Transformers for Chest X-ray Analysis"
     """
 
     def __init__(
@@ -321,7 +322,7 @@ def __init__(
 
             # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers,
            # 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head
-            net = VLTransformers(in_channels=3,
+            net = Transchex(in_channels=3,
                             img_size=(224, 224),
                             num_classes=3,
                             num_language_layers=2,
@@ -330,7 +331,7 @@ def __init__(
                             drop_out=0.2)
 
         """
-        super(VLTransformers, self).__init__()
+        super(Transchex, self).__init__()
         bert_config = {
             "attention_probs_dropout_prob": attention_probs_dropout_prob,
             "classifier_dropout": None,
diff --git a/tests/min_tests.py b/tests/min_tests.py
index bac6521889..f47a06b3bb 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -140,7 +140,7 @@ def run_testsuit():
         "test_zoom",
         "test_zoom_affine",
         "test_zoomd",
-        "test_vltransformer",
+        "test_transchex",
     ]
 
     assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
diff --git a/tests/test_vltransformer.py b/tests/test_transchex.py
similarity index 90%
rename from tests/test_vltransformer.py
rename to tests/test_transchex.py
index a92a9bf79a..716d3cc52e 100644
--- a/tests/test_vltransformer.py
+++ b/tests/test_transchex.py
@@ -15,9 +15,9 @@
 from parameterized import parameterized
 
 from monai.networks import eval_mode
-from monai.networks.nets.vltransformer import VLTransformers
+from monai.networks.nets.transchex import Transchex
 
-TEST_CASE_VLTransformers = []
+TEST_CASE_TRANSCHEX = []
 for drop_out in [0.4]:
     for in_channels in [3]:
         for img_size in [224]:
@@ -39,20 +39,20 @@
                         },
                         (2, num_classes),  # type: ignore
                     ]
-                    TEST_CASE_VLTransformers.append(test_case)
+                    TEST_CASE_TRANSCHEX.append(test_case)
 
 
 class TestPatchEmbeddingBlock(unittest.TestCase):
-    @parameterized.expand(TEST_CASE_VLTransformers)
+    @parameterized.expand(TEST_CASE_TRANSCHEX)
     def test_shape(self, input_param, expected_shape):
-        net = VLTransformers(**input_param)
+        net = Transchex(**input_param)
         with eval_mode(net):
             result = net(torch.randint(2, (2, 512)), torch.randint(2, (2, 512)), torch.randn((2, 3, 224, 224)))
             self.assertEqual(result.shape, expected_shape)
 
     def test_ill_arg(self):
         with self.assertRaises(ValueError):
-            VLTransformers(
+            Transchex(
                 in_channels=3,
                 img_size=(128, 128),
                 patch_size=(16, 16),
@@ -64,7 +64,7 @@ def test_ill_arg(self):
             )
 
         with self.assertRaises(ValueError):
-            VLTransformers(
+            Transchex(
                 in_channels=1,
                 img_size=(97, 97),
                 patch_size=(16, 16),
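
Note: the snippet below is a minimal usage sketch of the renamed class, not part of the diff. It mirrors the docstring example and the test_shape case above; it assumes the optional transformers package is installed, since the transchex module imports its BERT components from it (which is why test_transchex stays in the min_tests.py exclude list).

    # Minimal sketch: parameter values follow the docstring example in
    # transchex.py; input shapes follow test_shape in tests/test_transchex.py.
    # Requires the optional `transformers` dependency.
    import torch

    from monai.networks import eval_mode
    from monai.networks.nets.transchex import Transchex

    net = Transchex(
        in_channels=3,
        img_size=(224, 224),
        num_classes=3,
        num_language_layers=2,
        num_vision_layers=2,
        num_mixed_layers=2,
        drop_out=0.2,
    )

    with eval_mode(net):
        input_ids = torch.randint(2, (2, 512))       # batch of 2 token-id sequences
        token_type_ids = torch.randint(2, (2, 512))  # segment ids for the same batch
        images = torch.randn((2, 3, 224, 224))       # batch of 2 three-channel images
        logits = net(input_ids, token_type_ids, images)
        print(logits.shape)  # torch.Size([2, 3]), one score per class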