diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 51b474cfa8..5aab94c791 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -141,6 +141,24 @@ Metrics ------------------------------------- .. autoclass:: monai.metrics.regression.SSIMMetric +`Multi-scale structural similarity index measure` +------------------------------------------------- +.. autoclass:: MultiScaleSSIMMetric + +`Fréchet Inception Distance` +------------------------------ +.. autofunction:: compute_frechet_distance + +.. autoclass:: FIDMetric + :members: + +`Maximum Mean Discrepancy` +------------------------------ +.. autofunction:: compute_mmd + +.. autoclass:: MMDMetric + :members: + `Cumulative average` -------------------- .. autoclass:: CumulativeAverage @@ -156,6 +174,8 @@ Metrics .. autoclass:: MetricsReloadedCategorical :members: + + Utilities --------- .. automodule:: monai.metrics.utils diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 4af1b5760d..3809f59d2d 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -15,6 +15,7 @@ from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix from .cumulative_average import CumulativeAverage from .f_beta_score import FBetaScore +from .fid import FIDMetric, compute_frechet_distance from .froc import compute_fp_tp_probs, compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score from .generalized_dice import GeneralizedDiceScore, compute_generalized_dice from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance @@ -22,8 +23,18 @@ from .meandice import DiceHelper, DiceMetric, compute_dice from .meaniou import MeanIoU, compute_iou, compute_meaniou from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric +from .mmd import MMDMetric, compute_mmd from .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality -from .regression import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric, SSIMMetric +from .regression import ( + MAEMetric, + MSEMetric, + MultiScaleSSIMMetric, + PSNRMetric, + RMSEMetric, + SSIMMetric, + compute_ms_ssim, + compute_ssim_and_cs, +) from .rocauc import ROCAUCMetric, compute_roc_auc from .surface_dice import SurfaceDiceMetric, compute_surface_dice from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance diff --git a/monai/metrics/fid.py b/monai/metrics/fid.py new file mode 100644 index 0000000000..194d596f67 --- /dev/null +++ b/monai/metrics/fid.py @@ -0,0 +1,111 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import numpy as np +import torch + +from monai.metrics.metric import Metric +from monai.utils import optional_import + +scipy, _ = optional_import("scipy") + + +class FIDMetric(Metric): + """ + Frechet Inception Distance (FID). The FID calculates the distance between two distributions of feature vectors. + Based on: Heusel M. 
et al. "GANs trained by a two time-scale update rule converge to a local Nash equilibrium."
+    https://arxiv.org/abs/1706.08500. The inputs for this metric should be two groups of feature vectors (with format
+    (number images, number of features)) extracted from a pretrained network.
+
+    Originally, it was proposed to use the activations of the pool_3 layer of an Inception v3 pretrained on ImageNet.
+    However, other networks pretrained on medical datasets can be used as well (for example, RadImageNet for 2D and
+    MedicalNet for 3D images). If the chosen model output is not a scalar, global spatial average pooling should be
+    used.
+    """
+
+    def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return get_fid_score(y_pred, y)
+
+
+def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Computes the FID score metric on a batch of feature vectors.
+
+    Args:
+        y_pred: feature vectors extracted from a pretrained network run on generated images.
+        y: feature vectors extracted from a pretrained network run on images from the real data distribution.
+    """
+    y = y.double()
+    y_pred = y_pred.double()
+
+    if y.ndimension() > 2 or y_pred.ndimension() > 2:
+        raise ValueError("Inputs should have (number images, number of features) shape.")
+
+    mu_y_pred = torch.mean(y_pred, dim=0)
+    sigma_y_pred = _cov(y_pred, rowvar=False)
+    mu_y = torch.mean(y, dim=0)
+    sigma_y = _cov(y, rowvar=False)
+
+    return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y)
+
+
+def _cov(input_data: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
+    """
+    Estimate a covariance matrix of the variables.
+
+    Args:
+        input_data: A 1-D or 2-D array containing multiple variables and observations. Each row of `input_data`
+            represents a variable, and each column a single observation of all those variables.
+        rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns.
+            Otherwise, the relationship is transposed: each column represents a variable, while the rows contain
+            observations.
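+
+    A minimal sanity check (illustrative only; in this layout `_cov` matches `numpy.cov`):
+
+    .. code-block:: python
+
+        x = torch.randn(10, 5)  # 10 observations of 5 variables
+        cov = _cov(x, rowvar=False)  # (5, 5) tensor, np.cov(x.numpy(), rowvar=False) up to float precision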
+    """
+    if input_data.dim() < 2:
+        input_data = input_data.view(1, -1)
+
+    if not rowvar and input_data.size(0) != 1:
+        input_data = input_data.t()
+
+    factor = 1.0 / (input_data.size(1) - 1)
+    input_data = input_data - torch.mean(input_data, dim=1, keepdim=True)
+    return factor * input_data.matmul(input_data.t()).squeeze()
+
+
+def _sqrtm(input_data: torch.Tensor) -> torch.Tensor:
+    """Compute the square root of a matrix."""
+    scipy_res, _ = scipy.linalg.sqrtm(input_data.detach().cpu().numpy().astype(np.float64), disp=False)
+    return torch.from_numpy(scipy_res)
+
+
+def compute_frechet_distance(
+    mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, epsilon: float = 1e-6
+) -> torch.Tensor:
+    """The Frechet distance between multivariate normal distributions."""
+    diff = mu_x - mu_y
+
+    covmean = _sqrtm(sigma_x.mm(sigma_y))
+
+    # Product might be almost singular
+    if not torch.isfinite(covmean).all():
+        print(f"FID calculation produces singular product; adding {epsilon} to diagonal of covariance estimates")
+        offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon
+        covmean = _sqrtm((sigma_x + offset).mm(sigma_y + offset))
+
+    # Numerical error might give slight imaginary component
+    if torch.is_complex(covmean):
+        if not torch.allclose(torch.diagonal(covmean).imag, torch.tensor(0, dtype=torch.double), atol=1e-3):
+            raise ValueError(f"Imaginary component {torch.max(torch.abs(covmean.imag))} too high.")
+        covmean = covmean.real
+
+    tr_covmean = torch.trace(covmean)
+    return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean
diff --git a/monai/metrics/mmd.py b/monai/metrics/mmd.py
new file mode 100644
index 0000000000..5ba4cdf1b4
--- /dev/null
+++ b/monai/metrics/mmd.py
@@ -0,0 +1,91 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Callable
+
+import torch
+
+from monai.metrics.metric import Metric
+
+
+class MMDMetric(Metric):
+    """
+    Unbiased Maximum Mean Discrepancy (MMD) is a kernel-based method for measuring the similarity between two
+    distributions. It is a non-negative metric where a smaller value indicates a closer match between the two
+    distributions.
+
+    Gretton, A., et al., 2012. A kernel two-sample test. The Journal of Machine Learning Research, 13(1), pp.723-773.
+
+    Args:
+        y_mapping: Callable to transform the y tensors before computing the metric. It is usually a Gaussian or Laplace
+            filter, but it can be any function that takes a tensor as input and returns a tensor as output, such as a
+            feature extractor or an identity function, e.g. `y_mapping = lambda x: x.square()`.
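+
+    A minimal usage sketch (the shapes below are illustrative, not required):
+
+    .. code-block:: python
+
+        import torch
+        from monai.metrics import MMDMetric
+
+        metric = MMDMetric()  # no y_mapping: compare raw intensities
+        y = torch.randn(4, 1, 64, 64)  # reference batch
+        y_pred = torch.randn(4, 1, 64, 64)  # generated/reconstructed batch
+        score = metric(y, y_pred)  # scalar tensor; smaller means closer distributions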
+    """
+
+    def __init__(self, y_mapping: Callable | None = None) -> None:
+        super().__init__()
+        self.y_mapping = y_mapping
+
+    def __call__(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
+        return compute_mmd(y, y_pred, self.y_mapping)
+
+
+def compute_mmd(y: torch.Tensor, y_pred: torch.Tensor, y_mapping: Callable | None) -> torch.Tensor:
+    """
+    Computes the unbiased MMD between two samples.
+
+    Args:
+        y: first sample (e.g., the reference image). Its shape is (B,C,H,W) for 2D data and (B,C,H,W,D) for 3D.
+        y_pred: second sample (e.g., the reconstructed image). It must have the same shape as y.
+        y_mapping: Callable to transform the y tensors before computing the metric.
+    """
+    if y_pred.shape[0] == 1 or y.shape[0] == 1:
+        raise ValueError("MMD metric requires at least two samples in y and y_pred.")
+
+    if y_mapping is not None:
+        y = y_mapping(y)
+        y_pred = y_mapping(y_pred)
+
+    if y_pred.shape != y.shape:
+        raise ValueError(
+            "y_pred and y shapes don't match after being processed "
+            f"by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}"
+        )
+
+    # drop any trailing singleton dimensions, then flatten each sample to a vector
+    for d in range(len(y.shape) - 1, 1, -1):
+        y = y.squeeze(dim=d)
+        y_pred = y_pred.squeeze(dim=d)
+
+    y = y.view(y.shape[0], -1)
+    y_pred = y_pred.view(y_pred.shape[0], -1)
+
+    y_y = torch.mm(y, y.t())
+    y_pred_y_pred = torch.mm(y_pred, y_pred.t())
+    y_pred_y = torch.mm(y_pred, y.t())
+
+    m = y.shape[0]
+    n = y_pred.shape[0]
+
+    # Gretton et al. 2012, Eq. 3 (found under Lemma 6)
+    # term 1
+    c1 = 1 / (m * (m - 1))
+    a = torch.sum(y_y - torch.diag(torch.diagonal(y_y)))
+
+    # term 2
+    c2 = 1 / (n * (n - 1))
+    b = torch.sum(y_pred_y_pred - torch.diag(torch.diagonal(y_pred_y_pred)))
+
+    # term 3
+    c3 = 2 / (m * n)
+    c = torch.sum(y_pred_y)
+
+    mmd = c1 * a + c2 * b - c3 * c
+    return mmd
diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py
index c315a2eac0..f37230f09e 100644
--- a/monai/metrics/regression.py
+++ b/monai/metrics/regression.py
@@ -441,3 +441,168 @@ def compute_ssim_and_cs(
     ssim_value_full_image = ((2 * mu_x * mu_y + c1) / (mu_x**2 + mu_y**2 + c1)) * contrast_sensitivity
 
     return ssim_value_full_image, contrast_sensitivity
+
+
+class MultiScaleSSIMMetric(RegressionMetric):
+    """
+    Computes the Multi-Scale Structural Similarity Index Measure (MS-SSIM).
+
+    MS-SSIM reference paper:
+        Wang, Z., Simoncelli, E.P. and Bovik, A.C., 2003, November. "Multiscale structural
+        similarity for image quality assessment." In The Thirty-Seventh Asilomar Conference
+        on Signals, Systems & Computers, 2003 (Vol. 2, pp. 1398-1402). IEEE.
+
+    Args:
+        spatial_dims: number of spatial dimensions of the input images.
+        data_range: value range of input images (usually 1.0 or 255).
+        kernel_type: type of kernel, can be "gaussian" or "uniform".
+        kernel_size: size of the kernel.
+        kernel_sigma: standard deviation for the Gaussian kernel.
+        k1: stability constant used in the luminance denominator.
+        k2: stability constant used in the contrast denominator.
+        weights: parameters for image similarity and contrast sensitivity at different resolution scales.
+        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
+            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
+            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
if "none", will not do reduction + get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans) + """ + + def __init__( + self, + spatial_dims: int, + data_range: float = 1.0, + kernel_type: KernelType | str = KernelType.GAUSSIAN, + kernel_size: int | Sequence[int] = 11, + kernel_sigma: float | Sequence[float] = 1.5, + k1: float = 0.01, + k2: float = 0.03, + weights: Sequence[float] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, + ) -> None: + super().__init__(reduction=reduction, get_not_nans=get_not_nans) + + self.spatial_dims = spatial_dims + self.data_range = data_range + self.kernel_type = kernel_type + + if not isinstance(kernel_size, Sequence): + kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) + self.kernel_size = kernel_size + + if not isinstance(kernel_sigma, Sequence): + kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims) + self.kernel_sigma = kernel_sigma + + self.k1 = k1 + self.k2 = k2 + self.weights = weights + + def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return compute_ms_ssim( + y_pred=y_pred, + y=y, + spatial_dims=self.spatial_dims, + data_range=self.data_range, + kernel_type=self.kernel_type, + kernel_size=self.kernel_size, + kernel_sigma=self.kernel_sigma, + k1=self.k1, + k2=self.k2, + weights=self.weights, + ) + + +def compute_ms_ssim( + y_pred: torch.Tensor, + y: torch.Tensor, + spatial_dims: int, + data_range: float = 1.0, + kernel_type: KernelType | str = KernelType.GAUSSIAN, + kernel_size: int | Sequence[int] = 11, + kernel_sigma: float | Sequence[float] = 1.5, + k1: float = 0.01, + k2: float = 0.03, + weights: Sequence[float] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), +) -> torch.Tensor: + """ + Args: + y_pred: Predicted image. + It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. + y: Reference image. + It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. + spatial_dims: number of spatial dimensions of the input images. + data_range: value range of input images. (usually 1.0 or 255) + kernel_type: type of kernel, can be "gaussian" or "uniform". + kernel_size: size of kernel + kernel_sigma: standard deviation for Gaussian kernel. + k1: stability constant used in the luminance denominator + k2: stability constant used in the contrast denominator + weights: parameters for image similarity and contrast sensitivity at different resolution scores. + Raises: + ValueError: when `y_pred` is not a 2D or 3D image. + """ + dims = y_pred.ndimension() + if spatial_dims == 2 and dims != 4: + raise ValueError( + f"y_pred should have 4 dimensions (batch, channel, height, width) when using {spatial_dims} " + f"spatial dimensions, got {dims}." + ) + + if spatial_dims == 3 and dims != 5: + raise ValueError( + f"y_pred should have 4 dimensions (batch, channel, height, width, depth) when using {spatial_dims}" + f" spatial dimensions, got {dims}." 
+        )
+
+    if not isinstance(kernel_size, Sequence):
+        kernel_size = ensure_tuple_rep(kernel_size, spatial_dims)
+
+    if not isinstance(kernel_sigma, Sequence):
+        kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims)
+
+    # check that the images are large enough for the number of downsamplings and the kernel size
+    weights_div = max(1, (len(weights) - 1)) ** 2
+    y_pred_spatial_dims = y_pred.shape[2:]
+    for i in range(len(y_pred_spatial_dims)):
+        if y_pred_spatial_dims[i] // weights_div <= kernel_size[i] - 1:
+            raise ValueError(
+                f"For a given number of `weights` parameters {len(weights)} and kernel size "
+                f"{kernel_size[i]}, the image size along spatial dimension {i} must be larger than "
+                f"{(kernel_size[i] - 1) * weights_div}."
+            )
+
+    weights_tensor = torch.tensor(weights, device=y_pred.device, dtype=torch.float)
+
+    avg_pool = getattr(F, f"avg_pool{spatial_dims}d")
+
+    multiscale_list: list[torch.Tensor] = []
+    for _ in range(len(weights_tensor)):
+        ssim, cs = compute_ssim_and_cs(
+            y_pred=y_pred,
+            y=y,
+            spatial_dims=spatial_dims,
+            data_range=data_range,
+            kernel_type=kernel_type,
+            kernel_size=kernel_size,
+            kernel_sigma=kernel_sigma,
+            k1=k1,
+            k2=k2,
+        )
+
+        cs_per_batch = cs.view(cs.shape[0], -1).mean(1)
+
+        multiscale_list.append(torch.relu(cs_per_batch))
+        y_pred = avg_pool(y_pred, kernel_size=2)
+        y = avg_pool(y, kernel_size=2)
+
+    # the last scale uses the full SSIM (luminance included) instead of the contrast sensitivity
+    ssim = ssim.view(ssim.shape[0], -1).mean(1)
+    multiscale_list[-1] = torch.relu(ssim)
+    multiscale_list_tensor = torch.stack(multiscale_list)
+
+    ms_ssim_value_full_image = torch.prod(multiscale_list_tensor ** weights_tensor.view(-1, 1), dim=0)
+
+    ms_ssim_per_batch: torch.Tensor = ms_ssim_value_full_image.view(ms_ssim_value_full_image.shape[0], -1).mean(
+        1, keepdim=True
+    )
+
+    return ms_ssim_per_batch
diff --git a/tests/test_compute_fid_metric.py b/tests/test_compute_fid_metric.py
new file mode 100644
index 0000000000..1c7c3273fe
--- /dev/null
+++ b/tests/test_compute_fid_metric.py
@@ -0,0 +1,39 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
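+#
+# Note: the expected value 0.4444 in test_results below is 4/9, which can be verified
+# by hand: with x = [[1, 2], [1, 2], [1, 2]] and y = [[2, 2], [1, 2], [1, 2]], the means
+# differ by (-1/3, 0), cov(x) = 0 and cov(y) = [[1/3, 0], [0, 0]], so the sqrtm cross
+# term vanishes and FID = 1/9 + 0 + 1/3 = 4/9.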
+ +from __future__ import annotations + +import unittest + +import numpy as np +import torch + +from monai.metrics import FIDMetric +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") + + +@unittest.skipUnless(has_scipy, "Requires scipy") +class TestFIDMetric(unittest.TestCase): + def test_results(self): + x = torch.Tensor([[1, 2], [1, 2], [1, 2]]) + y = torch.Tensor([[2, 2], [1, 2], [1, 2]]) + results = FIDMetric()(x, y) + np.testing.assert_allclose(results.cpu().numpy(), 0.4444, atol=1e-4) + + def test_input_dimensions(self): + with self.assertRaises(ValueError): + FIDMetric()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_compute_mmd_metric.py b/tests/test_compute_mmd_metric.py new file mode 100644 index 0000000000..d1b69b3dfe --- /dev/null +++ b/tests/test_compute_mmd_metric.py @@ -0,0 +1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.metrics import MMDMetric + +TEST_CASES = [ + [{"y_mapping": None}, {"y": torch.ones([3, 3, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144])}, 0.0], + [{"y_mapping": None}, {"y": torch.ones([3, 3, 144, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144, 144])}, 0.0], + [ + {"y_mapping": lambda x: x.square()}, + {"y": torch.ones([3, 3, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144])}, + 0.0, + ], + [ + {"y_mapping": lambda x: x.square()}, + {"y": torch.ones([3, 3, 144, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144, 144])}, + 0.0, + ], +] + + +class TestMMDMetric(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_results(self, input_param, input_data, expected_val): + metric = MMDMetric(**input_param) + results = metric(**input_data) + np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4) + + def test_if_inputs_different_shapes(self): + with self.assertRaises(ValueError): + MMDMetric()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145])) + + def test_if_inputs_have_one_sample(self): + with self.assertRaises(ValueError): + MMDMetric()(torch.ones([1, 3, 144, 144]), torch.ones([1, 3, 144, 144])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_compute_multiscalessim_metric.py b/tests/test_compute_multiscalessim_metric.py new file mode 100644 index 0000000000..4ebc5b7935 --- /dev/null +++ b/tests/test_compute_multiscalessim_metric.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+
+from monai.metrics import MultiScaleSSIMMetric
+from monai.utils import set_determinism
+
+
+class TestMultiScaleSSIMMetric(unittest.TestCase):
+    def test2d_gaussian(self):
+        set_determinism(0)
+        preds = torch.abs(torch.randn(1, 1, 64, 64))
+        target = torch.abs(torch.randn(1, 1, 64, 64))
+        preds = preds / preds.max()
+        target = target / target.max()
+
+        metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="gaussian", weights=[0.5, 0.5])
+        metric(preds, target)
+        result = metric.aggregate()
+        expected_value = 0.023176
+        self.assertTrue(abs(expected_value - result.item()) < 0.000001)
+
+    def test2d_uniform(self):
+        set_determinism(0)
+        preds = torch.abs(torch.randn(1, 1, 64, 64))
+        target = torch.abs(torch.randn(1, 1, 64, 64))
+        preds = preds / preds.max()
+        target = target / target.max()
+
+        metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="uniform", weights=[0.5, 0.5])
+        metric(preds, target)
+        result = metric.aggregate()
+        expected_value = 0.022655
+        self.assertTrue(abs(expected_value - result.item()) < 0.000001)
+
+    def test3d_gaussian(self):
+        set_determinism(0)
+        preds = torch.abs(torch.randn(1, 1, 64, 64, 64))
+        target = torch.abs(torch.randn(1, 1, 64, 64, 64))
+        preds = preds / preds.max()
+        target = target / target.max()
+
+        metric = MultiScaleSSIMMetric(spatial_dims=3, data_range=1.0, kernel_type="gaussian", weights=[0.5, 0.5])
+        metric(preds, target)
+        result = metric.aggregate()
+        expected_value = 0.061796
+        self.assertTrue(abs(expected_value - result.item()) < 0.000001)
+
+    def test_ill_input_shape2d(self):
+        metric = MultiScaleSSIMMetric(spatial_dims=3, weights=[0.5, 0.5])
+
+        with self.assertRaises(ValueError):
+            metric(torch.randn(1, 1, 64, 64), torch.randn(1, 1, 64, 64))
+
+    def test_ill_input_shape3d(self):
+        metric = MultiScaleSSIMMetric(spatial_dims=2, weights=[0.5, 0.5])
+
+        with self.assertRaises(ValueError):
+            metric(torch.randn(1, 1, 64, 64, 64), torch.randn(1, 1, 64, 64, 64))
+
+    def test_small_inputs(self):
+        metric = MultiScaleSSIMMetric(spatial_dims=2)
+
+        # 16x16 is too small for the default five weights and kernel size 11
+        with self.assertRaises(ValueError):
+            metric(torch.randn(1, 1, 16, 16), torch.randn(1, 1, 16, 16))
+
+
+if __name__ == "__main__":
+    unittest.main()
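
A brief usage sketch for the three metrics added above (supplementary; the shapes and values
are illustrative, and feature extraction for FID/MMD is left to the caller):

    import torch
    from monai.metrics import FIDMetric, MMDMetric, MultiScaleSSIMMetric

    # MS-SSIM consumes image batches directly and accumulates across calls.
    ms_ssim = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, weights=[0.5, 0.5])
    ms_ssim(torch.rand(2, 1, 64, 64), torch.rand(2, 1, 64, 64))
    print(ms_ssim.aggregate())  # mean MS-SSIM over the accumulated results

    # FID and MMD consume one vector per sample, e.g. pooled activations of a pretrained
    # network; random features stand in for real activations here.
    feats_pred = torch.randn(64, 48)
    feats_real = torch.randn(64, 48)
    print(FIDMetric()(feats_pred, feats_real))  # distance between feature distributions
    print(MMDMetric()(feats_real, feats_pred))  # note the (y, y_pred) argument order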