diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 847993bd44..3321001af0 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -434,8 +434,7 @@ def __init__( self.dense4 = DenseBlock(spatial_dims, ndenselayer, num_init_features, densebn, densegrowth, 0.0) noutdense4 = num_init_features + densegrowth * ndenselayer - if psp_block_num > 0: - self.psp = PSP(spatial_dims, psp_block_num, noutdense4, upsample_mode) + self.psp = PSP(spatial_dims, psp_block_num, noutdense4, upsample_mode) self.final = Final(spatial_dims, psp_block_num + noutdense4, out_channels, upsample_mode) # Initialise parameters @@ -511,7 +510,7 @@ def forward(self, x): sum4 = self.up3(d3) + conv_x d4 = self.dense4(sum4) - if self.psp_block_num > 0 and self.psp is not None: + if self.psp_block_num > 0: psp = self.psp(d4) x = torch.cat((psp, d4), dim=1) else: