From de88aa4bd4123cb8670290b54af2e015386e3300 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 27 Apr 2023 08:22:53 +0100 Subject: [PATCH 1/2] fixes #1349 Signed-off-by: Wenqi Li --- reconstruction/MRI_reconstruction/varnet_demo/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/reconstruction/MRI_reconstruction/varnet_demo/train.py b/reconstruction/MRI_reconstruction/varnet_demo/train.py index a068094ed7..8b46b0dd69 100644 --- a/reconstruction/MRI_reconstruction/varnet_demo/train.py +++ b/reconstruction/MRI_reconstruction/varnet_demo/train.py @@ -137,9 +137,6 @@ def trainer(args): model.load_state_dict(torch.load(args.checkpoint_dir)) print("resume training from a given checkpoint...") - # create the loss function - loss_function = SSIMLoss(spatial_dims=2).to(device) - # create the optimizer and the learning rate scheduler optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_step_size, args.lr_gamma) @@ -166,6 +163,9 @@ def trainer(args): final_shape = target.shape[-2:] max_value = torch.tensor(max_value).unsqueeze(0).to(device) + # create the loss function with data_range + loss_function = SSIMLoss(spatial_dims=2, data_range=max_value).to(device) + # iterate through all slices slice_dim = 1 # change this if another dimension is your slice dimension num_slices = input.shape[slice_dim] @@ -183,7 +183,7 @@ def trainer(args): cropper = SpatialCrop(roi_center=roi_center, roi_size=final_shape) output_crp = cropper(output).unsqueeze(0) - loss = loss_function(output_crp, tar, max_value) + loss = loss_function(output_crp, tar) loss.backward() optimizer.step() From 94f6f4da91dc06976a7dd4c54854628b2c190c0d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 27 Apr 2023 08:59:19 +0100 Subject: [PATCH 2/2] update Signed-off-by: Wenqi Li --- reconstruction/MRI_reconstruction/varnet_demo/train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/reconstruction/MRI_reconstruction/varnet_demo/train.py b/reconstruction/MRI_reconstruction/varnet_demo/train.py index 8b46b0dd69..6e8c3d3949 100644 --- a/reconstruction/MRI_reconstruction/varnet_demo/train.py +++ b/reconstruction/MRI_reconstruction/varnet_demo/train.py @@ -137,6 +137,9 @@ def trainer(args): model.load_state_dict(torch.load(args.checkpoint_dir)) print("resume training from a given checkpoint...") + # create the loss function + loss_function = SSIMLoss(spatial_dims=2).to(device) + # create the optimizer and the learning rate scheduler optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_step_size, args.lr_gamma) @@ -162,9 +165,7 @@ def trainer(args): final_shape = target.shape[-2:] max_value = torch.tensor(max_value).unsqueeze(0).to(device) - - # create the loss function with data_range - loss_function = SSIMLoss(spatial_dims=2, data_range=max_value).to(device) + loss_function.ssim_metric.data_range = max_value # iterate through all slices slice_dim = 1 # change this if another dimension is your slice dimension