Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions monai/networks/nets/unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,24 +179,25 @@ def __init__(
res_block=res_block,
)
self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims))
self.proj_view_shape = list(self.feat_size) + [self.hidden_size]

def proj_feat(self, x, hidden_size, feat_size):
new_view = (x.size(0), *feat_size, hidden_size)
def proj_feat(self, x):
new_view = [x.size(0)] + self.proj_view_shape
x = x.view(new_view)
new_axes = (0, len(x.shape) - 1) + tuple(d + 1 for d in range(len(feat_size)))
x = x.permute(new_axes).contiguous()
x = x.permute(self.proj_axes).contiguous()
return x

def forward(self, x_in):
x, hidden_states_out = self.vit(x_in)
enc1 = self.encoder1(x_in)
x2 = hidden_states_out[3]
enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size))
enc2 = self.encoder2(self.proj_feat(x2))
x3 = hidden_states_out[6]
enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size))
enc3 = self.encoder3(self.proj_feat(x3))
x4 = hidden_states_out[9]
enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size))
dec4 = self.proj_feat(x, self.hidden_size, self.feat_size)
enc4 = self.encoder4(self.proj_feat(x4))
dec4 = self.proj_feat(x)
dec3 = self.decoder5(dec4, enc4)
dec2 = self.decoder4(dec3, enc3)
dec1 = self.decoder3(dec2, enc2)
Expand Down
14 changes: 13 additions & 1 deletion tests/test_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from monai.networks import eval_mode
from monai.networks.nets.unetr import UNETR
from tests.utils import SkipIfBeforePyTorchVersion, test_script_save

TEST_CASE_UNETR = []
for dropout_rate in [0.4]:
Expand Down Expand Up @@ -52,7 +53,7 @@
TEST_CASE_UNETR.append(test_case)


class TestPatchEmbeddingBlock(unittest.TestCase):
class TestUNETR(unittest.TestCase):
@parameterized.expand(TEST_CASE_UNETR)
def test_shape(self, input_param, input_shape, expected_shape):
net = UNETR(**input_param)
Expand Down Expand Up @@ -117,6 +118,17 @@ def test_ill_arg(self):
dropout_rate=0.2,
)

@parameterized.expand(TEST_CASE_UNETR)
@SkipIfBeforePyTorchVersion((1, 9))
def test_script(self, input_param, input_shape, _):
net = UNETR(**(input_param))
net.eval()
with torch.no_grad():
torch.jit.script(net)

test_data = torch.randn(input_shape)
test_script_save(net, test_data)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
TEST_CASE_Vit.append(test_case)


class TestPatchEmbeddingBlock(unittest.TestCase):
class TestViT(unittest.TestCase):
@parameterized.expand(TEST_CASE_Vit)
def test_shape(self, input_param, input_shape, expected_shape):
net = ViT(**input_param)
Expand Down