Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -500,9 +500,9 @@ Nets
.. autoclass:: Critic
:members:

`VLTransformers`
`Transchex`
~~~~~~~~~~~~~~~~
.. autoclass:: VLTransformers
.. autoclass:: Transchex
:members:

`NetAdapter`
Expand Down
10 changes: 1 addition & 9 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,9 @@
seresnext101,
)
from .torchvision_fc import TorchVisionFCModel, TorchVisionFullyConvModel
from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
from .unet import UNet, Unet, unet
from .unetr import UNETR
from .varautoencoder import VarAutoEncoder
from .vit import ViT
from .vltransformer import (
BertAttention,
BertMixedLayer,
BertOutput,
BertPreTrainedModel,
MultiModal,
Pooler,
VLTransformers,
)
from .vnet import VNet
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"BertMixedLayer",
"Pooler",
"MultiModal",
"VLTransformers",
"Transchex",
]


Expand Down Expand Up @@ -266,9 +266,10 @@ def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_m
return hidden_state_mixed


class VLTransformers(torch.nn.Module):
class Transchex(torch.nn.Module):
"""
Vision Language Multimodal Transformers
TransCheX based on: "Hatamizadeh et al., TransCheX: Self-Supervised Pretraining of Vision-Language
Transformers for Chest X-ray Analysis"
"""

def __init__(
Expand Down Expand Up @@ -321,7 +322,7 @@ def __init__(

# for a 3-channel input with image size of (224, 224), patch size of (32, 32), 3 classes, 2 language layers,
# 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head
net = VLTransformers(in_channels=3,
net = Transchex(in_channels=3,
img_size=(224, 224),
num_classes=3,
num_language_layers=2,
Expand All @@ -330,7 +331,7 @@ def __init__(
drop_out=0.2)

"""
super(VLTransformers, self).__init__()
super(Transchex, self).__init__()
bert_config = {
"attention_probs_dropout_prob": attention_probs_dropout_prob,
"classifier_dropout": None,
Expand Down
2 changes: 1 addition & 1 deletion tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def run_testsuit():
"test_zoom",
"test_zoom_affine",
"test_zoomd",
"test_vltransformer",
"test_transchex",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
14 changes: 7 additions & 7 deletions tests/test_vltransformer.py → tests/test_transchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets.vltransformer import VLTransformers
from monai.networks.nets.transchex import Transchex

TEST_CASE_VLTransformers = []
TEST_CASE_TRANSCHEX = []
for drop_out in [0.4]:
for in_channels in [3]:
for img_size in [224]:
Expand All @@ -39,20 +39,20 @@
},
(2, num_classes), # type: ignore
]
TEST_CASE_VLTransformers.append(test_case)
TEST_CASE_TRANSCHEX.append(test_case)


class TestPatchEmbeddingBlock(unittest.TestCase):
@parameterized.expand(TEST_CASE_VLTransformers)
@parameterized.expand(TEST_CASE_TRANSCHEX)
def test_shape(self, input_param, expected_shape):
net = VLTransformers(**input_param)
net = Transchex(**input_param)
with eval_mode(net):
result = net(torch.randint(2, (2, 512)), torch.randint(2, (2, 512)), torch.randn((2, 3, 224, 224)))
self.assertEqual(result.shape, expected_shape)

def test_ill_arg(self):
with self.assertRaises(ValueError):
VLTransformers(
Transchex(
in_channels=3,
img_size=(128, 128),
patch_size=(16, 16),
Expand All @@ -64,7 +64,7 @@ def test_ill_arg(self):
)

with self.assertRaises(ValueError):
VLTransformers(
Transchex(
in_channels=1,
img_size=(97, 97),
patch_size=(16, 16),
Expand Down