From b4e4de85b3fd0e00b79bc7dd2e2dfdafc32b75c4 Mon Sep 17 00:00:00 2001 From: Petru-Daniel Tudosiu Date: Tue, 20 Dec 2022 22:32:52 +0000 Subject: [PATCH 1/2] Added the MMD Metric and tests Signed-off-by: Petru-Daniel Tudosiu --- generative/metrics/__init__.py | 1 + generative/metrics/mmd.py | 94 ++++++++++++++++++++++++++++++++ tests/test_compute_mmd_metric.py | 47 ++++++++++++++++ 3 files changed, 142 insertions(+) create mode 100644 generative/metrics/mmd.py create mode 100644 tests/test_compute_mmd_metric.py diff --git a/generative/metrics/__init__.py b/generative/metrics/__init__.py index df6a5858..fec106c1 100644 --- a/generative/metrics/__init__.py +++ b/generative/metrics/__init__.py @@ -9,4 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .mmd import MMD from .ms_ssim import MSSSIM diff --git a/generative/metrics/mmd.py b/generative/metrics/mmd.py new file mode 100644 index 00000000..e77f6866 --- /dev/null +++ b/generative/metrics/mmd.py @@ -0,0 +1,94 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional, Union + +import torch +from monai.metrics.regression import RegressionMetric +from monai.utils import MetricReduction + + +class MMD(RegressionMetric): + """ + Unbiased Maximum Mean Discrepancy (MMD) is a kernel-based method for measuring the similarity between two + distributions. It is a non-negative metric where a smaller value indicates a closer match between the two + distributions. + + Args: + y_transform: Callable to transform the y tensor before computing the metric. It is usually a Gaussian or Laplace + filter, but it can be any function that takes a tensor as input and returns a tensor as output such as a + feature extractor or an Identity function. + y_pred_transform: Callable to transform the y_pred tensor before computing the metric. + reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, available + reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, + `"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. This parameter is ignored due to + the mathematical formulation of MMD. + get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here + `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. + This parameter is ignored due to the mathematical formulation of MMD. + + Ref: + Gretton, A., Borgwardt, K.M., Rasch, M.J., Schölkopf, B. and Smola, A., 2012. + A kernel two-sample test. + The Journal of Machine Learning Research, 13(1), pp.723-773. + """ + + def __init__( + self, + y_transform: Optional[Callable] = None, + y_pred_transform: Optional[Callable] = None, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + get_not_nans: bool = False, + ) -> None: + super().__init__(reduction=reduction, get_not_nans=get_not_nans) + + self.y_transform = y_transform + self.y_pred_transform = y_pred_transform + + def _compute_metric(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + """ + Args: + y: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D. + y_pred: second sample (e.g., the reconstructed image). It has similar shape as y. + weights: weights for each sample. It has shape (B,1) and the values should be non-negative. + """ + + # Beta and Gamma are not calculated since torch.mean is used at return + beta = 1.0 + gamma = 2.0 + + if self.y_transform is not None: + y = self.y_transform(y) + + if self.y_pred_transform is not None: + y_pred = self.y_pred_transform(y_pred) + + if y_pred.shape != y.shape: + raise ValueError( + f"y_pred and y shapes dont match after being processed by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}" + ) + + for d in range(len(y.shape) - 1, 1, -1): + y = y.squeeze(dim=d) + y_pred = y_pred.squeeze(dim=d) + + y = y.view(y.shape[0], -1) + y_pred = y_pred.view(y_pred.shape[0], -1) + + y_y = torch.mm(y, y_pred.t()) + y_pred_y_pred = torch.mm(y_pred, y_pred.t()) + y_pred_y = torch.mm(y_pred, y.t()) + + y_y = y_y / y.shape[1] + y_pred_y_pred = y_pred_y_pred / y.shape[1] + y_pred_y = y_pred_y / y.shape[1] + + # Ref. 1 Eq. 3 (found under Lemma 6) + return beta * (torch.mean(y_y) + torch.mean(y_pred_y_pred)) - gamma * torch.mean(y_pred_y) diff --git a/tests/test_compute_mmd_metric.py b/tests/test_compute_mmd_metric.py new file mode 100644 index 00000000..e7e26f3a --- /dev/null +++ b/tests/test_compute_mmd_metric.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from generative.metrics import MMD + +TEST_CASES = [ + [ + {"y_transform": None, "y_pred_transform": None}, + {"y": torch.ones([3, 3, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144])}, + 0.0, + ], + [ + {"y_transform": None, "y_pred_transform": None}, + {"y": torch.ones([3, 3, 144, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144, 144])}, + 0.0, + ], +] + + +class TestMMDMetric(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_results(self, input_param, input_data, expected_val): + results = MMD(**input_param)._compute_metric(**input_data) + np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4) + + def test_if_inputs_different_shapes(self): + with self.assertRaises(ValueError): + MMD()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145])) + + +if __name__ == "__main__": + unittest.main() From 98800c9258114f35a8647d134597008baf5e9b1c Mon Sep 17 00:00:00 2001 From: Petru-Daniel Tudosiu Date: Sun, 8 Jan 2023 09:24:12 +0000 Subject: [PATCH 2/2] Fixed MMD yy calculations and docs. --- generative/metrics/mmd.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/generative/metrics/mmd.py b/generative/metrics/mmd.py index e77f6866..3c652df8 100644 --- a/generative/metrics/mmd.py +++ b/generative/metrics/mmd.py @@ -21,6 +21,8 @@ class MMD(RegressionMetric): distributions. It is a non-negative metric where a smaller value indicates a closer match between the two distributions. + Gretton, A., et al,, 2012. A kernel two-sample test. The Journal of Machine Learning Research, 13(1), pp.723-773. + Args: y_transform: Callable to transform the y tensor before computing the metric. It is usually a Gaussian or Laplace filter, but it can be any function that takes a tensor as input and returns a tensor as output such as a @@ -34,10 +36,6 @@ class MMD(RegressionMetric): `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. This parameter is ignored due to the mathematical formulation of MMD. - Ref: - Gretton, A., Borgwardt, K.M., Rasch, M.J., Schölkopf, B. and Smola, A., 2012. - A kernel two-sample test. - The Journal of Machine Learning Research, 13(1), pp.723-773. """ def __init__( @@ -57,7 +55,6 @@ def _compute_metric(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor Args: y: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D. y_pred: second sample (e.g., the reconstructed image). It has similar shape as y. - weights: weights for each sample. It has shape (B,1) and the values should be non-negative. """ # Beta and Gamma are not calculated since torch.mean is used at return @@ -82,7 +79,7 @@ def _compute_metric(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor y = y.view(y.shape[0], -1) y_pred = y_pred.view(y_pred.shape[0], -1) - y_y = torch.mm(y, y_pred.t()) + y_y = torch.mm(y, y.t()) y_pred_y_pred = torch.mm(y_pred, y_pred.t()) y_pred_y = torch.mm(y_pred, y.t())