Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions generative/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .mmd import MMD
from .ms_ssim import MSSSIM
91 changes: 91 additions & 0 deletions generative/metrics/mmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union

import torch
from monai.metrics.regression import RegressionMetric
from monai.utils import MetricReduction


class MMD(RegressionMetric):
"""
Unbiased Maximum Mean Discrepancy (MMD) is a kernel-based method for measuring the similarity between two
distributions. It is a non-negative metric where a smaller value indicates a closer match between the two
distributions.

Gretton, A., et al,, 2012. A kernel two-sample test. The Journal of Machine Learning Research, 13(1), pp.723-773.

Args:
y_transform: Callable to transform the y tensor before computing the metric. It is usually a Gaussian or Laplace
filter, but it can be any function that takes a tensor as input and returns a tensor as output such as a
feature extractor or an Identity function.
y_pred_transform: Callable to transform the y_pred tensor before computing the metric.
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, available
reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``,
`"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. This parameter is ignored due to
the mathematical formulation of MMD.
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here
`not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.
This parameter is ignored due to the mathematical formulation of MMD.

"""

def __init__(
self,
y_transform: Optional[Callable] = None,
y_pred_transform: Optional[Callable] = None,
reduction: Union[MetricReduction, str] = MetricReduction.MEAN,
get_not_nans: bool = False,
) -> None:
super().__init__(reduction=reduction, get_not_nans=get_not_nans)

self.y_transform = y_transform
self.y_pred_transform = y_pred_transform

def _compute_metric(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
"""
Args:
y: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D.
y_pred: second sample (e.g., the reconstructed image). It has similar shape as y.
"""

# Beta and Gamma are not calculated since torch.mean is used at return
beta = 1.0
gamma = 2.0

if self.y_transform is not None:
y = self.y_transform(y)

if self.y_pred_transform is not None:
y_pred = self.y_pred_transform(y_pred)

if y_pred.shape != y.shape:
raise ValueError(
f"y_pred and y shapes dont match after being processed by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}"
)

for d in range(len(y.shape) - 1, 1, -1):
y = y.squeeze(dim=d)
y_pred = y_pred.squeeze(dim=d)

y = y.view(y.shape[0], -1)
y_pred = y_pred.view(y_pred.shape[0], -1)

y_y = torch.mm(y, y.t())
y_pred_y_pred = torch.mm(y_pred, y_pred.t())
y_pred_y = torch.mm(y_pred, y.t())

y_y = y_y / y.shape[1]
y_pred_y_pred = y_pred_y_pred / y.shape[1]
y_pred_y = y_pred_y / y.shape[1]

# Ref. 1 Eq. 3 (found under Lemma 6)
return beta * (torch.mean(y_y) + torch.mean(y_pred_y_pred)) - gamma * torch.mean(y_pred_y)
47 changes: 47 additions & 0 deletions tests/test_compute_mmd_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import numpy as np
import torch
from parameterized import parameterized

from generative.metrics import MMD

TEST_CASES = [
[
{"y_transform": None, "y_pred_transform": None},
{"y": torch.ones([3, 3, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144])},
0.0,
],
[
{"y_transform": None, "y_pred_transform": None},
{"y": torch.ones([3, 3, 144, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144, 144])},
0.0,
],
]


class TestMMDMetric(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_results(self, input_param, input_data, expected_val):
results = MMD(**input_param)._compute_metric(**input_data)
np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4)

def test_if_inputs_different_shapes(self):
with self.assertRaises(ValueError):
MMD()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))


if __name__ == "__main__":
unittest.main()