Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions generative/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .fid import FID
from .mmd import MMD
from .ms_ssim import MSSSIM
146 changes: 146 additions & 0 deletions generative/metrics/fid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# =========================================================================
# Adapted from https://github.com/photosynthesis-team/piq
# which has the following license:
# https://github.com/photosynthesis-team/piq/blob/master/LICENSE

# Copyright 2023 photosynthesis-team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

from __future__ import annotations

import torch
from monai.metrics.metric import Metric


class FID(Metric):
    """
    Frechet Inception Distance (FID): a distance between two distributions of feature vectors.

    Based on: Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium."
    https://arxiv.org/abs/1706.08500. Both inputs of this metric should be groups of feature vectors, with format
    (number images, number of features), extracted from a pretrained network.

    The original proposal uses the activations of the pool_3 layer of an Inception v3 pretrained on ImageNet.
    Other networks pretrained on medical datasets can be used as well (for example, RadImageNet for 2D and
    MedicalNet for 3D images). If the chosen model does not output one feature vector per image, a global spatial
    average pooling is usually applied first.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, y_pred: torch.Tensor, y: torch.Tensor):
        # All the work is delegated to the module-level function.
        fid_value = get_fid_score(y_pred, y)
        return fid_value


def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute the Frechet Inception Distance between two groups of feature vectors.

    Args:
        y_pred: feature vectors of the first group, with format (number images, number of features).
        y: feature vectors of the second group, with format (number images, number of features).

    Returns:
        The value produced by `compute_frechet_distance` for the per-feature means and the covariance
        matrices of both groups.

    Raises:
        ValueError: if either input has more than two dimensions.
    """
    # Both inputs are validated up front. Checking only `y` let a >2-D `y_pred` through, which then
    # failed later with an unrelated error while its covariance was being computed.
    if y_pred.ndimension() > 2 or y.ndimension() > 2:
        raise ValueError("Inputs should have (number images, number of features) shape.")

    y = y.float()
    y_pred = y_pred.float()

    # Rows are observations (images) and columns are variables (features), hence rowvar=False.
    mu_y_pred = torch.mean(y_pred, dim=0)
    sigma_y_pred = _cov(y_pred, rowvar=False)
    mu_y = torch.mean(y, dim=0)
    sigma_y = _cov(y, rowvar=False)

    return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y)


def _cov(m: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
"""
Estimate a covariance matrix of the variables.

Args:
m: A 1-D or 2-D array containing multiple variables and observations. Each row of `m` represents a variable,
and each column a single observation of all those variables.
rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns.
Otherwise, the relationship is transposed: each column represents a variable, while the rows contain
observations.
"""
if m.dim() < 2:
m = m.view(1, -1)

if not rowvar and m.size(0) != 1:
m = m.t()

fact = 1.0 / (m.size(1) - 1)
m = m - torch.mean(m, dim=1, keepdim=True)
mt = m.t()
return fact * m.matmul(mt).squeeze()


def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[torch.Tensor, torch.Tensor]:
"""
Square root of matrix using Newton-Schulz Iterative method. Based on:
https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py. Bechmark shown in:
https://github.com/photosynthesis-team/piq/issues/190#issuecomment-742039303

Args:
matrix: matrix or batch of matrices
num_iters: Number of iteration of the method

"""
dim = matrix.size(0)
norm_of_matrix = matrix.norm(p="fro")
y_matrix = matrix.div(norm_of_matrix)
i_matrix = torch.eye(dim, dim, device=matrix.device, dtype=matrix.dtype)
z_matrix = torch.eye(dim, dim, device=matrix.device, dtype=matrix.dtype)

s_matrix = torch.empty_like(matrix)
error = torch.empty(1, device=matrix.device, dtype=matrix.dtype)

for _ in range(num_iters):
T = 0.5 * (3.0 * i_matrix - z_matrix.mm(y_matrix))
y_matrix = y_matrix.mm(T)
z_matrix = T.mm(z_matrix)

s_matrix = y_matrix * torch.sqrt(norm_of_matrix)

norm_of_matrix = torch.norm(matrix)
error = matrix - torch.mm(s_matrix, s_matrix)
error = torch.norm(error) / norm_of_matrix

if torch.isclose(error, torch.tensor([0.0], device=error.device, dtype=error.dtype), atol=1e-5):
break

return s_matrix, error


def compute_frechet_distance(
    mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, epsilon: float = 1e-6
) -> torch.Tensor:
    """
    The Frechet distance between multivariate normal distributions.

    Args:
        mu_x: mean vector of the first distribution.
        sigma_x: covariance matrix of the first distribution.
        mu_y: mean vector of the second distribution.
        sigma_y: covariance matrix of the second distribution.
        epsilon: value added to the diagonal of both covariance matrices when the first square root
            estimate contains non-finite values.

    Returns:
        `|mu_x - mu_y|^2 + Tr(sigma_x) + Tr(sigma_y) - 2 * Tr(sqrt(sigma_x @ sigma_y))`.
    """
    mean_delta = mu_x - mu_y
    sqrt_product, _ = _sqrtm_newton_schulz(sigma_x.mm(sigma_y))

    # A non-finite square root triggers one retry with epsilon added to the diagonal of both
    # covariance estimates.
    if not torch.isfinite(sqrt_product).all():
        jitter = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon
        sqrt_product, _ = _sqrtm_newton_schulz((sigma_x + jitter).mm(sigma_y + jitter))

    trace_of_sqrt = torch.trace(sqrt_product)
    return mean_delta.dot(mean_delta) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * trace_of_sqrt
34 changes: 34 additions & 0 deletions tests/test_compute_fid_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import numpy as np
import torch

from generative.metrics import FID


class TestMMDMetric(unittest.TestCase):
    # NOTE(review): the class name mentions MMD, but every test here exercises the FID metric;
    # the name looks like a copy-paste leftover and could be renamed.

    def test_results(self):
        """FID of two small feature groups matches the expected reference value."""
        features_a = torch.Tensor([[1, 2], [1, 2], [1, 2]])
        features_b = torch.Tensor([[2, 2], [1, 2], [1, 2]])
        metric = FID()
        score = metric(features_a, features_b)
        np.testing.assert_allclose(score.cpu().numpy(), 0.4433, atol=1e-4)

    def test_input_dimensions(self):
        """Inputs with more than two dimensions are rejected with a ValueError."""
        image_like_a = torch.ones([3, 3, 144, 144])
        image_like_b = torch.ones([3, 3, 145, 145])
        metric = FID()
        with self.assertRaises(ValueError):
            metric(image_like_a, image_like_b)


if __name__ == "__main__":
    # Run the test cases of this module when it is executed directly as a script.
    unittest.main()