diff --git a/generative/metrics/ms_ssim.py b/generative/metrics/ms_ssim.py index 2a5047c7..769a7ff3 100644 --- a/generative/metrics/ms_ssim.py +++ b/generative/metrics/ms_ssim.py @@ -80,10 +80,7 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ if not x.shape == y.shape: - raise ValueError( - f"Input images should have the same dimensions, \ - but got {x.shape} and {y.shape}." - ) + raise ValueError(f"Input images should have the same dimensions, but got {x.shape} and {y.shape}.") for d in range(len(x.shape) - 1, 1, -1): x = x.squeeze(dim=d) @@ -94,10 +91,7 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: elif len(x.shape) == 5: avg_pool = F.avg_pool3d else: - raise ValueError( - f"Input images should be 4-d or 5-d tensors, but \ - got {x.shape}" - ) + raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {x.shape}") if self.weights is None: # as per Ref 1 - Sec 3.2. @@ -109,14 +103,14 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: for idx, shape_size in enumerate(x.shape[2:]): if shape_size % divisible_by != 0: raise ValueError( - f"Image size needs to be divisible by {divisible_by} but \ - dimension {idx + 2} has size {shape_size}" + f"Image size needs to be divisible by {divisible_by} but " + f"dimension {idx + 2} has size {shape_size}" ) if shape_size < bigger_than: raise ValueError( - f"Image size should be larger than {bigger_than} due to \ - the {len(self.weights) - 1} downsamplings in MS-SSIM." + f"Image size should be larger than {bigger_than} due to " + f"the {len(self.weights) - 1} downsamplings in MS-SSIM." ) levels = self.weights.shape[0]