From 0a6a79fd4a9d9f0b2548ce1b7515cd9a931552d8 Mon Sep 17 00:00:00 2001 From: Warvito Date: Sun, 30 Oct 2022 14:41:27 +0000 Subject: [PATCH 1/7] [WIP] Add FID --- generative/metrics/fid.py | 174 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 generative/metrics/fid.py diff --git a/generative/metrics/fid.py b/generative/metrics/fid.py new file mode 100644 index 00000000..845e3936 --- /dev/null +++ b/generative/metrics/fid.py @@ -0,0 +1,174 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union + +import torch +from monai.metrics import CumulativeIterationMetric +from monai.utils import MetricReduction +from torchvision.models import Inception_V3_Weights, inception_v3 + +RADIMAGENET_URL = "https://drive.google.com/uc?id=1p0q9AhG3rufIaaUE1jc2okpS8sdwN6PU" +RADIMAGENET_WEIGHTS = "RadImageNet-InceptionV3_notop.h5" + + +# TODO: get a better name for parameters +# TODO: Transform radimagenet's Keras weight to Torch weights following https://github.com/BMEII-AI/RadImageNet/issues/3 +# TODO: Create Mednet3D +class FID(CumulativeIterationMetric): + """ + Frechet Inception Distance (FID). FID can compare two data distributions with different number of samples. + But dimensionalities should match, otherwise it won't be possible to correctly compute statistics. Based on: + Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium." 
+ https://arxiv.org/abs/1706.08500# + + Args: + reduction: + extract_features: + feature_extractor: + """ + + def __init__( + self, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + extract_features: bool = True, + feature_extractor: str = "imagenet", + ) -> None: + super().__init__() + self.reduction = reduction + self.feature_extractor = feature_extractor + + # TODO: Download pretrained network. + self.network = None + if extract_features: + if feature_extractor == "imagenet:": + self.network = inception_v3(Inception_V3_Weights.IMAGENET1K_V1) + elif feature_extractor == "radimagenet": + self.network = inception_v3() + elif feature_extractor == "medicalnet": + pass + + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): + """ + Args: + y_pred: + y: + """ + pass + + def aggregate(self, reduction: Union[MetricReduction, str, None] = None): + """ + Args: + reduction: + + Returns: + """ + pass + + +def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Square root of matrix using Newton-Schulz Iterative method. 
Based on: + https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py + + Args: + matrix: matrix or batch of matrices + num_iters: Number of iteration of the method + + Returns: + Square root of matrix + Error + """ + dim = matrix.size(0) + norm_of_matrix = matrix.norm(p="fro") + y_matrix = matrix.div(norm_of_matrix) + i_matrix = torch.eye(dim, dim, device=matrix.device, dtype=matrix.dtype) + z_matrix = torch.eye(dim, dim, device=matrix.device, dtype=matrix.dtype) + + s_matrix = torch.empty_like(matrix) + error = torch.empty(1, device=matrix.device, dtype=matrix.dtype) + + for _ in range(num_iters): + T = 0.5 * (3.0 * i_matrix - z_matrix.mm(y_matrix)) + y_matrix = y_matrix.mm(T) + z_matrix = T.mm(z_matrix) + + s_matrix = y_matrix * torch.sqrt(norm_of_matrix) + + norm_of_matrix = torch.norm(matrix) + error = matrix - torch.mm(s_matrix, s_matrix) + error = torch.norm(error) / norm_of_matrix + + if torch.isclose(error, torch.tensor([0.0], device=error.device, dtype=error.dtype), atol=1e-5): + break + + return s_matrix, error + + +def _compute_statistics(samples: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculates the statistics used by FID + + Args: + samples: Low-dimension representation of image set. + Shape (N_samples, dims) and dtype: np.float32 in range 0 - 1 + + Returns: + mu: mean over all activations from the encoder. + sigma: covariance matrix over all activations from the encoder. 
+ """ + mu = torch.mean(samples, dim=0) + + # Estimate a covariance matrix + if samples.dim() < 2: + samples = samples.view(1, -1) + + if samples.size(0) != 1: + samples = samples.t() + + fact = 1.0 / (samples.size(1) - 1) + samples = samples - torch.mean(samples, dim=1, keepdim=True) + samplest = samples.t() + sigma = fact * samples.matmul(samplest).squeeze() + + return mu, sigma + + +def compute_fid_from_features(x_features: torch.Tensor, y_features: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + Fits multivariate Gaussians, then computes FID. + + Args: + x_features: Samples from data distribution. Shape :math:`(N_x, D)` + y_features: Samples from data distribution. Shape :math:`(N_y, D)` + eps: + + Returns: + The Frechet Distance. + """ + + mu_x, sigma_x = _compute_statistics(x_features) + mu_y, sigma_y = _compute_statistics(y_features) + + # The Frechet Inception Distance between two multivariate Gaussians X_x ~ N(mu_1, sigm_1) + # and X_y ~ N(mu_2, sigm_2) is d^2 = ||mu_1 - mu_2||^2 + Tr(sigm_1 + sigm_2 - 2*sqrt(sigm_1*sigm_2)). 
+ diff = mu_x - mu_y + covmean, _ = _sqrtm_newton_schulz(sigma_x.mm(sigma_y)) + + # Product might be almost singular + if not torch.isfinite(covmean).all(): + offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * eps + covmean, _ = _sqrtm_newton_schulz((sigma_x + offset).mm(sigma_y + offset)) + + tr_covmean = torch.trace(covmean) + score = diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean + + return score From 00d57b51944729c4f426f0437602f28f3aa8bbf2 Mon Sep 17 00:00:00 2001 From: Warvito Date: Thu, 3 Nov 2022 21:44:36 +0000 Subject: [PATCH 2/7] [WIP] Add FID --- generative/metrics/fid.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/generative/metrics/fid.py b/generative/metrics/fid.py index 845e3936..db56738a 100644 --- a/generative/metrics/fid.py +++ b/generative/metrics/fid.py @@ -13,6 +13,7 @@ import torch from monai.metrics import CumulativeIterationMetric +from monai.metrics.utils import do_metric_reduction from monai.utils import MetricReduction from torchvision.models import Inception_V3_Weights, inception_v3 @@ -23,6 +24,7 @@ # TODO: get a better name for parameters # TODO: Transform radimagenet's Keras weight to Torch weights following https://github.com/BMEII-AI/RadImageNet/issues/3 # TODO: Create Mednet3D +# TODO: Remove CumulativeIterationMetric class FID(CumulativeIterationMetric): """ Frechet Inception Distance (FID). FID can compare two data distributions with different number of samples. @@ -49,29 +51,45 @@ def __init__( # TODO: Download pretrained network. 
self.network = None if extract_features: - if feature_extractor == "imagenet:": - self.network = inception_v3(Inception_V3_Weights.IMAGENET1K_V1) + if feature_extractor == "imagenet": + weights = Inception_V3_Weights.IMAGENET1K_V1 + self.network = inception_v3(weights=weights).eval() elif feature_extractor == "radimagenet": - self.network = inception_v3() + self.network = inception_v3().eval() elif feature_extractor == "medicalnet": pass - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore """ Args: y_pred: y: """ - pass + # check dimension + dims = y_pred.ndimension() + if dims < 2: + raise ValueError("y_pred should have at least two dimensions.") + + # TODO: GET features + y_pred_features = None + y_features = None + + return compute_fid_from_features(y_pred_features, y_features) def aggregate(self, reduction: Union[MetricReduction, str, None] = None): """ Args: - reduction: - - Returns: + reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction. 
""" - pass + data = self.get_buffer() + if not isinstance(data, torch.Tensor): + raise ValueError("The data to aggregate must be a PyTorch Tensor.") + + # do metric reduction + f, not_nans = do_metric_reduction(data, reduction or self.reduction) + return (f, not_nans) if self.get_not_nans else f def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> Tuple[torch.Tensor, torch.Tensor]: From 1cf4ac178c5b3aaca7a56bffb3a08728bc3a4dfc Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 21 Jan 2023 11:44:51 +0000 Subject: [PATCH 3/7] Add reference copyright information Signed-off-by: Walter Hugo Lopez Pinaya --- generative/metrics/fid.py | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/generative/metrics/fid.py b/generative/metrics/fid.py index db56738a..c412af71 100644 --- a/generative/metrics/fid.py +++ b/generative/metrics/fid.py @@ -9,10 +9,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +# ========================================================================= +# Adapted from https://github.com/photosynthesis-team/piq +# which has the following license: +# https://github.com/photosynthesis-team/piq/blob/master/LICENSE + +# Copyright 2023 photosynthesis-team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ========================================================================= + +from __future__ import annotations import torch -from monai.metrics import CumulativeIterationMetric +from monai.metrics.metric import Metric from monai.metrics.utils import do_metric_reduction from monai.utils import MetricReduction from torchvision.models import Inception_V3_Weights, inception_v3 @@ -24,8 +44,7 @@ # TODO: get a better name for parameters # TODO: Transform radimagenet's Keras weight to Torch weights following https://github.com/BMEII-AI/RadImageNet/issues/3 # TODO: Create Mednet3D -# TODO: Remove CumulativeIterationMetric -class FID(CumulativeIterationMetric): +class FID(Metric): """ Frechet Inception Distance (FID). FID can compare two data distributions with different number of samples. But dimensionalities should match, otherwise it won't be possible to correctly compute statistics. Based on: @@ -40,7 +59,7 @@ class FID(CumulativeIterationMetric): def __init__( self, - reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, extract_features: bool = True, feature_extractor: str = "imagenet", ) -> None: @@ -76,7 +95,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor return compute_fid_from_features(y_pred_features, y_features) - def aggregate(self, reduction: Union[MetricReduction, str, None] = None): + def aggregate(self, reduction: MetricReduction | str | None = None): """ Args: reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, @@ -92,7 +111,7 @@ def aggregate(self, reduction: Union[MetricReduction, str, None] = None): return (f, not_nans) if self.get_not_nans else f -def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> Tuple[torch.Tensor, torch.Tensor]: +def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[torch.Tensor, torch.Tensor]: """ Square root of matrix using 
Newton-Schulz Iterative method. Based on: https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py @@ -131,7 +150,7 @@ def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> Tuple[to return s_matrix, error -def _compute_statistics(samples: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def _compute_statistics(samples: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Calculates the statistics used by FID From cb86f3146d4234e80e11bd0de0542469d81b34d2 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sun, 22 Jan 2023 13:45:26 +0000 Subject: [PATCH 4/7] Refactor FID metric Signed-off-by: Walter Hugo Lopez Pinaya --- generative/metrics/fid.py | 115 +++++++++++++------------------------- 1 file changed, 38 insertions(+), 77 deletions(-) diff --git a/generative/metrics/fid.py b/generative/metrics/fid.py index c412af71..0b872f13 100644 --- a/generative/metrics/fid.py +++ b/generative/metrics/fid.py @@ -33,17 +33,9 @@ import torch from monai.metrics.metric import Metric -from monai.metrics.utils import do_metric_reduction -from monai.utils import MetricReduction from torchvision.models import Inception_V3_Weights, inception_v3 -RADIMAGENET_URL = "https://drive.google.com/uc?id=1p0q9AhG3rufIaaUE1jc2okpS8sdwN6PU" -RADIMAGENET_WEIGHTS = "RadImageNet-InceptionV3_notop.h5" - -# TODO: get a better name for parameters -# TODO: Transform radimagenet's Keras weight to Torch weights following https://github.com/BMEII-AI/RadImageNet/issues/3 -# TODO: Create Mednet3D class FID(Metric): """ Frechet Inception Distance (FID). FID can compare two data distributions with different number of samples. @@ -51,67 +43,54 @@ class FID(Metric): Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium." 
https://arxiv.org/abs/1706.08500# - Args: - reduction: - extract_features: - feature_extractor: """ def __init__( self, - reduction: MetricReduction | str = MetricReduction.MEAN, - extract_features: bool = True, - feature_extractor: str = "imagenet", + feature_extractor_type: str | None = "imagenet", ) -> None: super().__init__() - self.reduction = reduction - self.feature_extractor = feature_extractor + self.feature_extractor_type = feature_extractor_type + self.feature_extractor = None - # TODO: Download pretrained network. - self.network = None - if extract_features: - if feature_extractor == "imagenet": + if feature_extractor_type: + if feature_extractor_type == "imagenet": + # TODO: Add feature extractor weights = Inception_V3_Weights.IMAGENET1K_V1 - self.network = inception_v3(weights=weights).eval() - elif feature_extractor == "radimagenet": - self.network = inception_v3().eval() - elif feature_extractor == "medicalnet": - pass - - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore - """ - Args: - y_pred: - y: - """ - # check dimension - dims = y_pred.ndimension() - if dims < 2: + self.feature_extractor = inception_v3(weights=weights).eval() + elif feature_extractor_type == "radimagenet": + # TODO: Add feature extractor + self.feature_extractor = inception_v3().eval() + elif feature_extractor_type == "medicalnet": + # TODO: Add feature extractor + self.feature_extractor = inception_v3().eval() + + def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): + if self.feature_extractor_type in ["radimagenet", "imagenet"] and ( + y_pred.ndimension() != 4 or y.ndimension() != 4 + ): + raise ValueError("FID requires RGB images.") + + if self.feature_extractor_type == "medicalnet" and (y_pred.ndimension() != 5 or y.ndimension() != 5): + raise ValueError("FID requires RGB images.") + + if y_pred.ndimension() < 2 or y.ndimension() < 2: raise ValueError("y_pred should have at least two dimensions.") - # TODO: GET features - 
y_pred_features = None - y_features = None - - return compute_fid_from_features(y_pred_features, y_features) + if self.feature_extractor: + y_pred_features = self.feature_extractor.features(y_pred) + y_features = self.feature_extractor.features(y) + else: + y_pred_features = y_pred + y_features = y - def aggregate(self, reduction: MetricReduction | str | None = None): - """ - Args: - reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, - available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction. - """ - data = self.get_buffer() - if not isinstance(data, torch.Tensor): - raise ValueError("The data to aggregate must be a PyTorch Tensor.") + mu_y_pred, sigma_y_pred = compute_statistics(y_pred_features) + mu_y, sigma_y = compute_statistics(y_features) - # do metric reduction - f, not_nans = do_metric_reduction(data, reduction or self.reduction) - return (f, not_nans) if self.get_not_nans else f + return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y) -def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[torch.Tensor, torch.Tensor]: +def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> torch.Tensor: """ Square root of matrix using Newton-Schulz Iterative method. 
Based on: https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py @@ -120,9 +99,6 @@ def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[to matrix: matrix or batch of matrices num_iters: Number of iteration of the method - Returns: - Square root of matrix - Error """ dim = matrix.size(0) norm_of_matrix = matrix.norm(p="fro") @@ -150,7 +126,7 @@ def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[to return s_matrix, error -def _compute_statistics(samples: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def compute_statistics(samples: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Calculates the statistics used by FID @@ -179,24 +155,11 @@ def _compute_statistics(samples: torch.Tensor) -> tuple[torch.Tensor, torch.Tens return mu, sigma -def compute_fid_from_features(x_features: torch.Tensor, y_features: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: +def compute_frechet_distance(mu_x, sigma_x, mu_y, sigma_y, eps=1e-6): """ - Fits multivariate Gaussians, then computes FID. - - Args: - x_features: Samples from data distribution. Shape :math:`(N_x, D)` - y_features: Samples from data distribution. Shape :math:`(N_y, D)` - eps: - - Returns: - The Frechet Distance. + The Frechet distance between two multivariate Gaussians """ - mu_x, sigma_x = _compute_statistics(x_features) - mu_y, sigma_y = _compute_statistics(y_features) - - # The Frechet Inception Distance between two multivariate Gaussians X_x ~ N(mu_1, sigm_1) - # and X_y ~ N(mu_2, sigm_2) is d^2 = ||mu_1 - mu_2||^2 + Tr(sigm_1 + sigm_2 - 2*sqrt(sigm_1*sigm_2)). 
diff = mu_x - mu_y covmean, _ = _sqrtm_newton_schulz(sigma_x.mm(sigma_y)) @@ -206,6 +169,4 @@ def compute_fid_from_features(x_features: torch.Tensor, y_features: torch.Tensor covmean, _ = _sqrtm_newton_schulz((sigma_x + offset).mm(sigma_y + offset)) tr_covmean = torch.trace(covmean) - score = diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean - - return score + return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean From 0f5753b0c6af2937027d93994b2125ddccce0675 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sun, 22 Jan 2023 14:01:23 +0000 Subject: [PATCH 5/7] Add medicalnet feature extractor Signed-off-by: Walter Hugo Lopez Pinaya --- generative/metrics/fid.py | 58 +++++++++++++++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 6 deletions(-) diff --git a/generative/metrics/fid.py b/generative/metrics/fid.py index 0b872f13..6cac0e23 100644 --- a/generative/metrics/fid.py +++ b/generative/metrics/fid.py @@ -32,6 +32,7 @@ from __future__ import annotations import torch +import torch.nn as nn from monai.metrics.metric import Metric from torchvision.models import Inception_V3_Weights, inception_v3 @@ -48,6 +49,7 @@ class FID(Metric): def __init__( self, feature_extractor_type: str | None = "imagenet", + verbose: bool = False, ) -> None: super().__init__() self.feature_extractor_type = feature_extractor_type @@ -55,15 +57,16 @@ def __init__( if feature_extractor_type: if feature_extractor_type == "imagenet": - # TODO: Add feature extractor weights = Inception_V3_Weights.IMAGENET1K_V1 self.feature_extractor = inception_v3(weights=weights).eval() elif feature_extractor_type == "radimagenet": - # TODO: Add feature extractor - self.feature_extractor = inception_v3().eval() + raise NotImplementedError("Radimage net not implemented yet") elif feature_extractor_type == "medicalnet": - # TODO: Add feature extractor - self.feature_extractor = inception_v3().eval() + self.feature_extractor = 
MedicalNetFeatureExtractor( + net="medicalnet_resnet10_23datasets", verbose=verbose + ) + + self.feature_extractor.eval() def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): if self.feature_extractor_type in ["radimagenet", "imagenet"] and ( @@ -155,7 +158,9 @@ def compute_statistics(samples: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso return mu, sigma -def compute_frechet_distance(mu_x, sigma_x, mu_y, sigma_y, eps=1e-6): +def compute_frechet_distance( + mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: """ The Frechet distance between two multivariate Gaussians """ @@ -170,3 +175,44 @@ def compute_frechet_distance(mu_x, sigma_x, mu_y, sigma_y, eps=1e-6): tr_covmean = torch.trace(covmean) return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean + + +class MedicalNetFeatureExtractor(nn.Module): + """ + Component to compute the features with the networks pretrained by Chen, et al. "Med3D: Transfer Learning for 3D + Medical Image Analysis". This class uses torch Hub to download the networks from "Warvito/MedicalNet-models". + + Args: + net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} + Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``. + verbose: if false, mute messages from torch Hub load function. + """ + + def __init__( + self, + net: str = "medicalnet_resnet10_23datasets", + verbose: bool = False, + ) -> None: + super().__init__() + torch.hub._validate_not_a_forked_repo = lambda a, b, c: True + self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose) + self.eval() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute features using MedicalNet 3D networks. + + Args: + input: 3D input tensor with shape BCDHW. 
+ """ + x = medicalnet_intensity_normalisation(x) + return self.model.forward(x) + + +def medicalnet_intensity_normalisation(volume): + """Intensity normalisation based on + https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/datasets/brains18.py#L133 + """ + mean = volume.mean() + std = volume.std() + return (volume - mean) / std From 480d3e737841b6fff251f0417fdf6a22e725072a Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sun, 22 Jan 2023 14:22:26 +0000 Subject: [PATCH 6/7] Remove feature extractors from implementation Signed-off-by: Walter Hugo Lopez Pinaya --- generative/metrics/fid.py | 98 ++++++--------------------------------- 1 file changed, 13 insertions(+), 85 deletions(-) diff --git a/generative/metrics/fid.py b/generative/metrics/fid.py index 6cac0e23..3dd136cc 100644 --- a/generative/metrics/fid.py +++ b/generative/metrics/fid.py @@ -32,68 +32,37 @@ from __future__ import annotations import torch -import torch.nn as nn from monai.metrics.metric import Metric -from torchvision.models import Inception_V3_Weights, inception_v3 class FID(Metric): """ - Frechet Inception Distance (FID). FID can compare two data distributions with different number of samples. - But dimensionalities should match, otherwise it won't be possible to correctly compute statistics. Based on: + Frechet Inception Distance (FID). The FID calculates the distance between two groups of feature vectors. Based on: Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium." 
https://arxiv.org/abs/1706.08500# - """ - def __init__( - self, - feature_extractor_type: str | None = "imagenet", - verbose: bool = False, - ) -> None: + def __init__(self) -> None: super().__init__() - self.feature_extractor_type = feature_extractor_type - self.feature_extractor = None - - if feature_extractor_type: - if feature_extractor_type == "imagenet": - weights = Inception_V3_Weights.IMAGENET1K_V1 - self.feature_extractor = inception_v3(weights=weights).eval() - elif feature_extractor_type == "radimagenet": - raise NotImplementedError("Radimage net not implemented yet") - elif feature_extractor_type == "medicalnet": - self.feature_extractor = MedicalNetFeatureExtractor( - net="medicalnet_resnet10_23datasets", verbose=verbose - ) - - self.feature_extractor.eval() def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): - if self.feature_extractor_type in ["radimagenet", "imagenet"] and ( - y_pred.ndimension() != 4 or y.ndimension() != 4 - ): - raise ValueError("FID requires RGB images.") + return get_fid_score(y_pred, y) - if self.feature_extractor_type == "medicalnet" and (y_pred.ndimension() != 5 or y.ndimension() != 5): - raise ValueError("FID requires RGB images.") - if y_pred.ndimension() < 2 or y.ndimension() < 2: - raise ValueError("y_pred should have at least two dimensions.") +def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + y = y.float() + y_pred = y_pred.float() - if self.feature_extractor: - y_pred_features = self.feature_extractor.features(y_pred) - y_features = self.feature_extractor.features(y) - else: - y_pred_features = y_pred - y_features = y + if y.shape != y_pred.shape: + raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") - mu_y_pred, sigma_y_pred = compute_statistics(y_pred_features) - mu_y, sigma_y = compute_statistics(y_features) + mu_y_pred, sigma_y_pred = compute_statistics(y_pred) + mu_y, sigma_y = compute_statistics(y) - return 
compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y) + return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y) -def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> torch.Tensor: +def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[torch.Tensor, torch.Tensor]: """ Square root of matrix using Newton-Schulz Iterative method. Based on: https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py @@ -162,7 +131,7 @@ def compute_frechet_distance( mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, eps: float = 1e-6 ) -> torch.Tensor: """ - The Frechet distance between two multivariate Gaussians + The Frechet distance between two multivariate Gaussians. """ diff = mu_x - mu_y @@ -175,44 +144,3 @@ def compute_frechet_distance( tr_covmean = torch.trace(covmean) return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean - - -class MedicalNetFeatureExtractor(nn.Module): - """ - Component to compute the features with the networks pretrained by Chen, et al. "Med3D: Transfer Learning for 3D - Medical Image Analysis". This class uses torch Hub to download the networks from "Warvito/MedicalNet-models". - - Args: - net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} - Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``. - verbose: if false, mute messages from torch Hub load function. - """ - - def __init__( - self, - net: str = "medicalnet_resnet10_23datasets", - verbose: bool = False, - ) -> None: - super().__init__() - torch.hub._validate_not_a_forked_repo = lambda a, b, c: True - self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose) - self.eval() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Compute features using MedicalNet 3D networks. - - Args: - input: 3D input tensor with shape BCDHW. 
- """ - x = medicalnet_intensity_normalisation(x) - return self.model.forward(x) - - -def medicalnet_intensity_normalisation(volume): - """Intensity normalisation based on - https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/datasets/brains18.py#L133 - """ - mean = volume.mean() - std = volume.std() - return (volume - mean) / std From 96b5b31c22d4b60ea5b00c2944632b23860d6f52 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sun, 22 Jan 2023 18:24:08 +0000 Subject: [PATCH 7/7] Add tests Signed-off-by: Walter Hugo Lopez Pinaya --- generative/metrics/__init__.py | 1 + generative/metrics/fid.py | 88 ++++++++++++++++---------------- tests/test_compute_fid_metric.py | 34 ++++++++++++ 3 files changed, 79 insertions(+), 44 deletions(-) create mode 100644 tests/test_compute_fid_metric.py diff --git a/generative/metrics/__init__.py b/generative/metrics/__init__.py index fec106c1..bd4b9acc 100644 --- a/generative/metrics/__init__.py +++ b/generative/metrics/__init__.py @@ -9,5 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .fid import FID from .mmd import MMD from .ms_ssim import MSSSIM diff --git a/generative/metrics/fid.py b/generative/metrics/fid.py index 3dd136cc..afae75a1 100644 --- a/generative/metrics/fid.py +++ b/generative/metrics/fid.py @@ -37,9 +37,15 @@ class FID(Metric): """ - Frechet Inception Distance (FID). The FID calculates the distance between two groups of feature vectors. Based on: - Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium." - https://arxiv.org/abs/1706.08500# + Frechet Inception Distance (FID). The FID calculates the distance between two distributions of feature vectors. + Based on: Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium." + https://arxiv.org/abs/1706.08500#. 
The inputs for this metric should be two groups of feature vectors (with format + (number images, number of features)) extracted from a pretrained network. + + Originally, it was proposed to use the activations of the pool_3 layer of an Inception v3 pretrained with Imagenet. + However, other networks pretrained on medical datasets can be used as well (for example, RadImageNet for 2D and + MedicalNet for 3D images). If the chosen model output is not a scalar, a global spatial + average pooling is usually applied. """ def __init__(self) -> None: @@ -53,19 +59,45 @@ def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: y = y.float() y_pred = y_pred.float() - if y.shape != y_pred.shape: - raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") + if y.ndimension() > 2: + raise ValueError(f"Inputs should have (number images, number of features) shape.") - mu_y_pred, sigma_y_pred = compute_statistics(y_pred) - mu_y, sigma_y = compute_statistics(y) + mu_y_pred = torch.mean(y_pred, dim=0) + sigma_y_pred = _cov(y_pred, rowvar=False) + mu_y = torch.mean(y, dim=0) + sigma_y = _cov(y, rowvar=False) return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y) +def _cov(m: torch.Tensor, rowvar: bool = True) -> torch.Tensor: + """ + Estimate a covariance matrix of the variables. + + Args: + m: A 1-D or 2-D array containing multiple variables and observations. Each row of `m` represents a variable, + and each column a single observation of all those variables. + rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns. + Otherwise, the relationship is transposed: each column represents a variable, while the rows contain + observations. 
+ """ + if m.dim() < 2: + m = m.view(1, -1) + + if not rowvar and m.size(0) != 1: + m = m.t() + + fact = 1.0 / (m.size(1) - 1) + m = m - torch.mean(m, dim=1, keepdim=True) + mt = m.t() + return fact * m.matmul(mt).squeeze() + + def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[torch.Tensor, torch.Tensor]: """ Square root of matrix using Newton-Schulz Iterative method. Based on: - https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py + https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py. Benchmark shown in: + https://github.com/photosynthesis-team/piq/issues/190#issuecomment-742039303 Args: matrix: matrix or batch of matrices @@ -98,48 +130,16 @@ def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[to return s_matrix, error -def compute_statistics(samples: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Calculates the statistics used by FID - - Args: - samples: Low-dimension representation of image set. - Shape (N_samples, dims) and dtype: np.float32 in range 0 - 1 - - Returns: - mu: mean over all activations from the encoder. - sigma: covariance matrix over all activations from the encoder. - """ - mu = torch.mean(samples, dim=0) - - # Estimate a covariance matrix - if samples.dim() < 2: - samples = samples.view(1, -1) - - if samples.size(0) != 1: - samples = samples.t() - - fact = 1.0 / (samples.size(1) - 1) - samples = samples - torch.mean(samples, dim=1, keepdim=True) - samplest = samples.t() - sigma = fact * samples.matmul(samplest).squeeze() - - return mu, sigma - - def compute_frechet_distance( - mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, eps: float = 1e-6 + mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, epsilon: float = 1e-6 ) -> torch.Tensor: - """ - The Frechet distance between two multivariate Gaussians. 
- """ - + """The Frechet distance between multivariate normal distributions.""" diff = mu_x - mu_y covmean, _ = _sqrtm_newton_schulz(sigma_x.mm(sigma_y)) - # Product might be almost singular + # If calculation produces singular product, epsilon is added to diagonal of cov estimates if not torch.isfinite(covmean).all(): - offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * eps + offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon covmean, _ = _sqrtm_newton_schulz((sigma_x + offset).mm(sigma_y + offset)) tr_covmean = torch.trace(covmean) diff --git a/tests/test_compute_fid_metric.py b/tests/test_compute_fid_metric.py new file mode 100644 index 00000000..3cbc6180 --- /dev/null +++ b/tests/test_compute_fid_metric.py @@ -0,0 +1,34 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +import torch + +from generative.metrics import FID + + +class TestMMDMetric(unittest.TestCase): + def test_results(self): + x = torch.Tensor([[1, 2], [1, 2], [1, 2]]) + y = torch.Tensor([[2, 2], [1, 2], [1, 2]]) + results = FID()(x, y) + np.testing.assert_allclose(results.cpu().numpy(), 0.4433, atol=1e-4) + + def test_input_dimensions(self): + with self.assertRaises(ValueError): + FID()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145])) + + +if __name__ == "__main__": + unittest.main()