diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 3e496979bf..cb0c69d033 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -34,7 +34,6 @@ import math from collections.abc import Sequence from functools import reduce -from typing import Optional import numpy as np import torch @@ -2016,7 +2015,7 @@ def __init__( last_dim_flattened = int(reduce(lambda x, y: x * y, input_shape) * channels[-1]) - self.out: Optional[nn.Module] = nn.Sequential( + self.out: nn.Module = nn.Sequential( nn.Linear(last_dim_flattened, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels) ) @@ -2063,9 +2062,6 @@ def forward( h = h.reshape(h.shape[0], -1) # 5. out - self.out = nn.Sequential( - nn.Linear(h.shape[1], 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels) - ) output: torch.Tensor = self.out(h) return output