From 225645e94b7a8857d347d4999f111f2d8ad1fa82 Mon Sep 17 00:00:00 2001 From: myron Date: Thu, 28 Sep 2023 22:57:45 -0700 Subject: [PATCH] segresnet_ds better peak GPU mem Signed-off-by: myron --- monai/networks/nets/segresnet_ds.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index 96250c7443..6430f5fdc9 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -119,8 +119,7 @@ def __init__( def forward(self, x): identity = x - x = self.conv1(self.act1(self.norm1(x))) - x = self.conv2(self.act2(self.norm2(x))) + x = self.conv2(self.act2(self.norm2(self.conv1(self.act1(self.norm1(x)))))) x += identity return x @@ -408,7 +407,7 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tens i = 0 for level in self.up_layers: x = level["upsample"](x) - x = x + x_down[i] + x += x_down.pop(0) x = level["blocks"](x) if len(self.up_layers) - i <= self.dsdepth: