diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py
index b22f0584a2..c53936d27f 100644
--- a/monai/networks/nets/unetr.py
+++ b/monai/networks/nets/unetr.py
@@ -179,24 +179,25 @@ def __init__(
             res_block=res_block,
         )
         self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
+        self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims))
+        self.proj_view_shape = list(self.feat_size) + [self.hidden_size]
 
-    def proj_feat(self, x, hidden_size, feat_size):
-        new_view = (x.size(0), *feat_size, hidden_size)
+    def proj_feat(self, x):
+        new_view = [x.size(0)] + self.proj_view_shape
         x = x.view(new_view)
-        new_axes = (0, len(x.shape) - 1) + tuple(d + 1 for d in range(len(feat_size)))
-        x = x.permute(new_axes).contiguous()
+        x = x.permute(self.proj_axes).contiguous()
         return x
 
     def forward(self, x_in):
         x, hidden_states_out = self.vit(x_in)
         enc1 = self.encoder1(x_in)
         x2 = hidden_states_out[3]
-        enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size))
+        enc2 = self.encoder2(self.proj_feat(x2))
         x3 = hidden_states_out[6]
-        enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size))
+        enc3 = self.encoder3(self.proj_feat(x3))
         x4 = hidden_states_out[9]
-        enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size))
-        dec4 = self.proj_feat(x, self.hidden_size, self.feat_size)
+        enc4 = self.encoder4(self.proj_feat(x4))
+        dec4 = self.proj_feat(x)
         dec3 = self.decoder5(dec4, enc4)
         dec2 = self.decoder4(dec3, enc3)
         dec1 = self.decoder3(dec2, enc2)
diff --git a/tests/test_unetr.py b/tests/test_unetr.py
index a8c7b7bf88..40619de9dc 100644
--- a/tests/test_unetr.py
+++ b/tests/test_unetr.py
@@ -16,6 +16,7 @@
 
 from monai.networks import eval_mode
 from monai.networks.nets.unetr import UNETR
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save
 
 TEST_CASE_UNETR = []
 for dropout_rate in [0.4]:
@@ -52,7 +53,7 @@
             TEST_CASE_UNETR.append(test_case)
 
 
-class TestPatchEmbeddingBlock(unittest.TestCase):
+class TestUNETR(unittest.TestCase):
     @parameterized.expand(TEST_CASE_UNETR)
     def test_shape(self, input_param, input_shape, expected_shape):
         net = UNETR(**input_param)
@@ -117,6 +118,17 @@ def test_ill_arg(self):
                 dropout_rate=0.2,
             )
 
+    @parameterized.expand(TEST_CASE_UNETR)
+    @SkipIfBeforePyTorchVersion((1, 9))
+    def test_script(self, input_param, input_shape, _):
+        net = UNETR(**(input_param))
+        net.eval()
+        with torch.no_grad():
+            torch.jit.script(net)
+
+        test_data = torch.randn(input_shape)
+        test_script_save(net, test_data)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_vit.py b/tests/test_vit.py
index d5ae209e50..3ef847626d 100644
--- a/tests/test_vit.py
+++ b/tests/test_vit.py
@@ -55,7 +55,7 @@
             TEST_CASE_Vit.append(test_case)
 
 
-class TestPatchEmbeddingBlock(unittest.TestCase):
+class TestViT(unittest.TestCase):
     @parameterized.expand(TEST_CASE_Vit)
     def test_shape(self, input_param, input_shape, expected_shape):
         net = ViT(**input_param)
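The change above precomputes proj_feat's permute axes and view shape once in __init__, so the method no longer builds them dynamically from per-call arguments, and the new test_script case checks that the model compiles with torch.jit.script. A minimal usage sketch along those lines; the constructor arguments are illustrative assumptions, not values taken from this diff:

# Minimal sketch; the hyperparameters below are illustrative, not from this PR.
# Requires a MONAI build that includes this patch and PyTorch >= 1.9, as in test_script.
import torch
from monai.networks.nets import UNETR

model = UNETR(in_channels=1, out_channels=2, img_size=(96, 96, 96)).eval()

# proj_axes / proj_view_shape are now plain module attributes, so the model scripts cleanly.
scripted = torch.jit.script(model)

with torch.no_grad():
    out = scripted(torch.randn(1, 1, 96, 96, 96))
print(out.shape)  # expected: torch.Size([1, 2, 96, 96, 96])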