From 6b58da06b3c779bdb9d44aa57bc03cb555f74915 Mon Sep 17 00:00:00 2001 From: Lucas Robinet Date: Thu, 21 Mar 2024 15:00:44 +0100 Subject: [PATCH] Fixing requires_grad for sincos positional encoding in monai/networks/blocks/patchembedding.py and associated tests Signed-off-by: Lucas Robinet --- monai/networks/blocks/patchembedding.py | 1 + monai/networks/blocks/pos_embed_utils.py | 2 +- tests/test_patchembedding.py | 26 ++++++++++++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 7d56045814..44774ce5da 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -123,6 +123,7 @@ def __init__( with torch.no_grad(): pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims) self.position_embeddings.data.copy_(pos_embeddings.float()) + self.position_embeddings.requires_grad = False else: raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.") diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py index e03553307e..21586e56da 100644 --- a/monai/networks/blocks/pos_embed_utils.py +++ b/monai/networks/blocks/pos_embed_utils.py @@ -46,7 +46,7 @@ def build_sincos_position_embedding( temperature (float): The temperature for the sin-cos position embedding. Returns: - pos_embed (nn.Parameter): The sin-cos position embedding as a learnable parameter. + pos_embed (nn.Parameter): The sin-cos position embedding as a fixed parameter. """ if spatial_dims == 2: diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index f8610d9214..d059145033 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -93,6 +93,32 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + def test_sincos_pos_embed(self): + net = PatchEmbeddingBlock( + in_channels=1, + img_size=(32, 32, 32), + patch_size=(8, 8, 8), + hidden_size=96, + num_heads=8, + pos_embed_type="sincos", + dropout_rate=0.5, + ) + + self.assertEqual(net.position_embeddings.requires_grad, False) + + def test_learnable_pos_embed(self): + net = PatchEmbeddingBlock( + in_channels=1, + img_size=(32, 32, 32), + patch_size=(8, 8, 8), + hidden_size=96, + num_heads=8, + pos_embed_type="learnable", + dropout_rate=0.5, + ) + + self.assertEqual(net.position_embeddings.requires_grad, True) + def test_ill_arg(self): with self.assertRaises(ValueError): PatchEmbeddingBlock(