diff --git a/reconstruction/MRI_reconstruction/varnet_demo/train.py b/reconstruction/MRI_reconstruction/varnet_demo/train.py index a068094ed7..6e8c3d3949 100644 --- a/reconstruction/MRI_reconstruction/varnet_demo/train.py +++ b/reconstruction/MRI_reconstruction/varnet_demo/train.py @@ -165,6 +165,7 @@ def trainer(args): final_shape = target.shape[-2:] max_value = torch.tensor(max_value).unsqueeze(0).to(device) + loss_function.ssim_metric.data_range = max_value # iterate through all slices slice_dim = 1 # change this if another dimension is your slice dimension @@ -183,7 +184,7 @@ def trainer(args): cropper = SpatialCrop(roi_center=roi_center, roi_size=final_shape) output_crp = cropper(output).unsqueeze(0) - loss = loss_function(output_crp, tar, max_value) + loss = loss_function(output_crp, tar) loss.backward() optimizer.step()