diff --git a/monai/networks/nets/autoencoder.py b/monai/networks/nets/autoencoder.py index c2239450f2..8d0aadafd6 100644 --- a/monai/networks/nets/autoencoder.py +++ b/monai/networks/nets/autoencoder.py @@ -53,6 +53,10 @@ def __init__( self.inter_channels = inter_channels if inter_channels is not None else list() self.inter_dilations = list(inter_dilations or [1] * len(self.inter_channels)) + # The number of channels and strides should match + if len(channels) != len(strides): + raise ValueError("Autoencoder expects matching number of channels and strides") + self.encoded_channels = in_channels decode_channel_list = list(channels[-2::-1]) + [out_channels] diff --git a/tests/test_autoencoder.py b/tests/test_autoencoder.py index a7749d7f3a..86b31e0361 100644 --- a/tests/test_autoencoder.py +++ b/tests/test_autoencoder.py @@ -62,6 +62,15 @@ CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3] +TEST_CASE_FAIL = { # 2-channel 2D, should fail because of stride/channel mismatch. + "dimensions": 2, + "in_channels": 2, + "out_channels": 2, + "channels": (4, 8, 16), + "strides": (2, 2), +} + + class TestAutoEncoder(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): @@ -76,6 +85,10 @@ def test_script(self): test_data = torch.randn(2, 1, 32, 32) test_script_save(net, test_data) + def test_channel_stride_difference(self): + with self.assertRaises(ValueError): + net = AutoEncoder(**TEST_CASE_FAIL) + if __name__ == "__main__": unittest.main()