6676 port generative metrics #6836
Merged
wyli
merged 18 commits into
Project-MONAI:dev
from
marksgraham:6676_port_generative_metrics
Aug 16, 2023
Commits (all by marksgraham):
a74e995 Adds FID metric
c98c4e7 Adds MMD metric
4bbfa2a Optional imports
917e504 Adds ms-ssim metric
3f9a492 Mypy fixes
b7734c6 DCO Remediation Commit for Mark Graham <markgraham539@gmail.com>
edf2746 Undo minor change for DCO commit
fd4f5d7 Merge branch 'dev' into 6676_port_generative_metrics
3c81128 Update monai/metrics/fid.py
78cc8cc Update docs
fb8ee08 Updates docstring
5efe5b0 Updates MMD calculation to match original paper, and provide just a s…
a474743 Merge branch 'Project-MONAI:dev' into 6676_port_generative_metrics
a5cdba9 Make variables lowercase
20de356 Fix variable name
3920644 Add check for batch_size=1
4b0ec8f Fixes formatting error
66833c4 Merge branch 'dev' into 6676_port_generative_metrics
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from monai.metrics.metric import Metric | ||
| from monai.utils import optional_import | ||
|
|
||
| scipy, _ = optional_import("scipy") | ||
|
|
||
|
|
||
| class FIDMetric(Metric): | ||
| """ | ||
| Frechet Inception Distance (FID). The FID calculates the distance between two distributions of feature vectors. | ||
| Based on: Heusel M. et al. "GANs trained by a two time-scale update rule converge to a local Nash equilibrium." | ||
| https://arxiv.org/abs/1706.08500. The inputs for this metric should be two groups of feature vectors (with format | ||
| (number images, number of features)) extracted from a pretrained network. | ||
|
|
||
| Originally, it was proposed to use the activations of the pool_3 layer of an Inception v3 pretrained on ImageNet. | ||
| However, other networks pretrained on medical datasets can be used as well (for example, RadImageNet for 2D and | ||
| MedicalNet for 3D images). If the chosen model output still has spatial dimensions, global spatial average | ||
| pooling should be applied to obtain one feature vector per image. | ||
| """ | ||
|
|
||
| def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| return get_fid_score(y_pred, y) | ||
|
|
||
|
|
||
| def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| """Computes the FID score metric on a batch of feature vectors. | ||
|
|
||
| Args: | ||
| y_pred: feature vectors extracted from a pretrained network run on generated images. | ||
| y: feature vectors extracted from a pretrained network run on images from the real data distribution. | ||
| """ | ||
| y = y.double() | ||
| y_pred = y_pred.double() | ||
|
|
||
| if y.ndimension() > 2 or y_pred.ndimension() > 2: | ||
| raise ValueError("Inputs should have (number images, number of features) shape.") | ||
|
|
||
| mu_y_pred = torch.mean(y_pred, dim=0) | ||
| sigma_y_pred = _cov(y_pred, rowvar=False) | ||
| mu_y = torch.mean(y, dim=0) | ||
| sigma_y = _cov(y, rowvar=False) | ||
|
|
||
| return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y) | ||
|
|
||
|
|
||
| def _cov(input_data: torch.Tensor, rowvar: bool = True) -> torch.Tensor: | ||
| """ | ||
| Estimate a covariance matrix of the variables. | ||
|
|
||
| Args: | ||
| input_data: A 1-D or 2-D array containing multiple variables and observations. Each row of `input_data` represents a variable, | ||
| and each column a single observation of all those variables. | ||
| rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns. | ||
| Otherwise, the relationship is transposed: each column represents a variable, while the rows contain | ||
| observations. | ||
| """ | ||
| if input_data.dim() < 2: | ||
| input_data = input_data.view(1, -1) | ||
|
|
||
| if not rowvar and input_data.size(0) != 1: | ||
| input_data = input_data.t() | ||
|
|
||
| factor = 1.0 / (input_data.size(1) - 1) | ||
| input_data = input_data - torch.mean(input_data, dim=1, keepdim=True) | ||
| return factor * input_data.matmul(input_data.t()).squeeze() | ||
|
|
||
|
|
||
| def _sqrtm(input_data: torch.Tensor) -> torch.Tensor: | ||
| """Compute the square root of a matrix.""" | ||
| scipy_res, _ = scipy.linalg.sqrtm(input_data.detach().cpu().numpy().astype(np.float64), disp=False) | ||
| return torch.from_numpy(scipy_res) | ||
|
|
||
|
|
||
| def compute_frechet_distance( | ||
| mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, epsilon: float = 1e-6 | ||
| ) -> torch.Tensor: | ||
| """The Frechet distance between multivariate normal distributions.""" | ||
| diff = mu_x - mu_y | ||
|
|
||
| covmean = _sqrtm(sigma_x.mm(sigma_y)) | ||
|
|
||
| # Product might be almost singular | ||
| if not torch.isfinite(covmean).all(): | ||
| print(f"FID calculation produces singular product; adding {epsilon} to diagonal of covariance estimates") | ||
| offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon | ||
| covmean = _sqrtm((sigma_x + offset).mm(sigma_y + offset)) | ||
|
|
||
| # Numerical error might give slight imaginary component | ||
| if torch.is_complex(covmean): | ||
| if not torch.allclose(torch.diagonal(covmean).imag, torch.tensor(0, dtype=torch.double), atol=1e-3): | ||
| raise ValueError(f"Imaginary component {torch.max(torch.abs(covmean.imag))} too high.") | ||
| covmean = covmean.real | ||
|
|
||
| tr_covmean = torch.trace(covmean) | ||
| return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean | ||
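
A hand-checkable special case can help when reviewing `compute_frechet_distance`. When both covariance matrices are diagonal, the matrix square root of their product reduces to an elementwise square root, so the distance becomes the squared mean difference plus the sum of `(sqrt(var_x) - sqrt(var_y)) ** 2` over the features. The sketch below is not part of this PR: `frechet_distance_diagonal` is a hypothetical helper that assumes diagonal covariances passed as plain lists of variances, and it uses only the standard library.

```python
from __future__ import annotations

import math


def frechet_distance_diagonal(
    mu_x: list[float], var_x: list[float], mu_y: list[float], var_y: list[float]
) -> float:
    """Frechet distance between two Gaussians with DIAGONAL covariances.

    Hypothetical helper for sanity checks only; the general case needs a
    full matrix square root, as in the diff above.
    """
    # Squared Euclidean distance between the means.
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_x, mu_y))
    # trace(S_x) + trace(S_y) - 2 * trace(sqrt(S_x S_y)) for diagonal S_x, S_y.
    cov_term = sum(vx + vy - 2.0 * math.sqrt(vx * vy) for vx, vy in zip(var_x, var_y))
    return mean_term + cov_term


# Means differ by (1, 2); the variances are swapped between the two Gaussians.
print(frechet_distance_diagonal([0.0, 0.0], [1.0, 4.0], [1.0, 2.0], [4.0, 1.0]))  # 7.0
```

Feeding the same means and diagonal covariances into the general implementation should give the same value up to numerical error, which makes this a convenient unit-test fixture.
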
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Callable | ||
|
|
||
| import torch | ||
|
|
||
| from monai.metrics.metric import Metric | ||
|
|
||
|
|
||
| class MMDMetric(Metric): | ||
| """ | ||
| Unbiased Maximum Mean Discrepancy (MMD) is a kernel-based method for measuring the similarity between two | ||
| distributions. A smaller value indicates a closer match between the two distributions. The population MMD is | ||
| non-negative, but the unbiased estimate computed here can be negative for finite samples. | ||
|
|
||
| Gretton, A., et al., 2012. A kernel two-sample test. The Journal of Machine Learning Research, 13(1), pp.723-773. | ||
|
|
||
| Args: | ||
| y_mapping: Callable to transform the y tensors before computing the metric. It is usually a Gaussian or Laplace | ||
| filter, but it can be any function that takes a tensor as input and returns a tensor as output, such as a | ||
| feature extractor or an identity function, e.g. `y_mapping = lambda x: x.square()`. | ||
| """ | ||
|
|
||
| def __init__(self, y_mapping: Callable | None = None) -> None: | ||
| super().__init__() | ||
| self.y_mapping = y_mapping | ||
|
|
||
| def __call__(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: | ||
| return compute_mmd(y, y_pred, self.y_mapping) | ||
|
|
||
|
|
||
| def compute_mmd(y: torch.Tensor, y_pred: torch.Tensor, y_mapping: Callable | None) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
| y: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D. | ||
| y_pred: second sample (e.g., the reconstructed image). It must have the same shape as y. | ||
| y_mapping: Callable to transform the y tensors before computing the metric. | ||
| """ | ||
| if y_pred.shape[0] < 2 or y.shape[0] < 2: | ||
| raise ValueError("MMD metric requires at least two samples in y and y_pred.") | ||
|
|
||
| if y_mapping is not None: | ||
| y = y_mapping(y) | ||
| y_pred = y_mapping(y_pred) | ||
|
|
||
| if y_pred.shape != y.shape: | ||
| raise ValueError( | ||
| "y_pred and y shapes don't match after being processed " | ||
| f"by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}" | ||
| ) | ||
|
|
||
| for d in range(len(y.shape) - 1, 1, -1): | ||
| y = y.squeeze(dim=d) | ||
| y_pred = y_pred.squeeze(dim=d) | ||
|
|
||
| y = y.view(y.shape[0], -1) | ||
| y_pred = y_pred.view(y_pred.shape[0], -1) | ||
|
|
||
| y_y = torch.mm(y, y.t()) | ||
| y_pred_y_pred = torch.mm(y_pred, y_pred.t()) | ||
| y_pred_y = torch.mm(y_pred, y.t()) | ||
|
|
||
| m = y.shape[0] | ||
| n = y_pred.shape[0] | ||
|
|
||
| # Ref. 1 Eq. 3 (found under Lemma 6) | ||
|
||
| # term 1 | ||
| c1 = 1 / (m * (m - 1)) | ||
| a = torch.sum(y_y - torch.diag(torch.diagonal(y_y))) | ||
|
|
||
| # term 2 | ||
| c2 = 1 / (n * (n - 1)) | ||
| b = torch.sum(y_pred_y_pred - torch.diag(torch.diagonal(y_pred_y_pred))) | ||
|
|
||
| # term 3 | ||
| c3 = 2 / (m * n) | ||
| c = torch.sum(y_pred_y) | ||
|
|
||
| mmd = c1 * a + c2 * b - c3 * c | ||
| return mmd | ||
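
As a rough sketch of what `compute_mmd` evaluates once its inputs are flattened to (batch, features), the three terms can be written out with plain Python lists and a linear kernel `k(a, b) = <a, b>`. This is an illustration under stated assumptions, not the PR's code: `linear_mmd` is a hypothetical name, the inputs are assumed to be already flattened, and no `y_mapping` is applied.

```python
from __future__ import annotations


def linear_mmd(y: list[list[float]], y_pred: list[list[float]]) -> float:
    """Unbiased MMD^2 estimate with a linear kernel (hypothetical helper)."""
    m, n = len(y), len(y_pred)
    if m < 2 or n < 2:
        raise ValueError("need at least two samples in y and y_pred")

    def dot(u: list[float], v: list[float]) -> float:
        return sum(p * q for p, q in zip(u, v))

    # term 1: kernel sum within y, excluding the diagonal (i == j)
    a = sum(dot(y[i], y[j]) for i in range(m) for j in range(m) if i != j)
    # term 2: kernel sum within y_pred, excluding the diagonal
    b = sum(dot(y_pred[i], y_pred[j]) for i in range(n) for j in range(n) if i != j)
    # term 3: cross term over every (y_pred, y) pair
    c = sum(dot(p, q) for p in y_pred for q in y)

    return a / (m * (m - 1)) + b / (n * (n - 1)) - 2.0 * c / (m * n)


y = [[0.0, 0.0], [2.0, 0.0]]
y_pred = [[1.0, 1.0], [1.0, 3.0]]
print(linear_mmd(y, y_pred))  # 2.0

# Because the diagonal is dropped from terms 1 and 2 but not from term 3,
# the estimate can be negative, even for two identical samples.
same = [[1.0], [-1.0]]
print(linear_mmd(same, same))  # -2.0
```

The second call is a reminder that downstream code should not assume the returned value is non-negative.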