diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index 11a5682de6..d2a3ca4a53 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -572,8 +572,8 @@ class SmartCacheDataset(CacheDataset):
         4. Call `shutdown()` when training ends.
 
     Note:
-        This replacement will not work if set the `multiprocessing_context` of DataLoader to `spawn`
-        or on windows(the default multiprocessing method is `spawn`) and set `num_workers` greater than 0 .
+        This replacement will not work if the `multiprocessing_context` of DataLoader is set to `spawn`
+        (the default multiprocessing method on Windows) and `num_workers` is greater than 0.
 
     """
 
diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py
index dbca70bf25..2273b9ee89 100644
--- a/monai/handlers/roc_auc.py
+++ b/monai/handlers/roc_auc.py
@@ -9,13 +9,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 
+from monai.handlers.utils import evenly_divisible_all_gather
 from monai.metrics import compute_roc_auc
 from monai.utils import Average, exact_version, optional_import
 
+idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
 EpochMetric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "EpochMetric")
 
 
@@ -71,9 +73,32 @@ def _compute_fn(pred, label):
                 average=Average(average),
             )
 
+        self._is_reduced: bool = False
         super().__init__(
             compute_fn=_compute_fn,
             output_transform=output_transform,
             check_compute_fn=False,
             device=device,
         )
+
+    def compute(self) -> Any:
+        _prediction_tensor = torch.cat(self._predictions, dim=0)
+        _target_tensor = torch.cat(self._targets, dim=0)
+
+        ws = idist.get_world_size()
+        if ws > 1 and not self._is_reduced:
+            # All gather across all processes
+            _prediction_tensor = evenly_divisible_all_gather(_prediction_tensor)
+            _target_tensor = evenly_divisible_all_gather(_target_tensor)
+        self._is_reduced = True
+
+        result: torch.Tensor = torch.zeros(1)
+        if idist.get_rank() == 0:
+            # Run compute_fn on zero rank only
+            result = self.compute_fn(_prediction_tensor, _target_tensor)
+
+        if ws > 1:
+            # broadcast result to all processes
+            result = idist.broadcast(result, src=0)
+
+        return result.item() if torch.is_tensor(result) else result
diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py
index 2165ad8860..d0179e7f49 100644
--- a/monai/handlers/utils.py
+++ b/monai/handlers/utils.py
@@ -62,6 +62,9 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor:
     Args:
         data: source tensor to pad and execute all_gather in distributed data parallel.
 
+    Note:
+        The input data on different ranks must have exactly the same `dtype`.
+
     """
     if not isinstance(data, torch.Tensor):
         raise ValueError("input data must be PyTorch Tensor.")
diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py
index 825b172064..c5cf44162c 100644
--- a/tests/test_handler_rocauc_dist.py
+++ b/tests/test_handler_rocauc_dist.py
@@ -31,12 +31,12 @@ def test_compute(self):
             auc_metric.update([y_pred, y])
 
         if dist.get_rank() == 1:
-            y_pred = torch.tensor([[0.2, 0.1], [0.1, 0.5]], device=device)
-            y = torch.tensor([[0], [1]], device=device)
+            y_pred = torch.tensor([[0.2, 0.1], [0.1, 0.5], [0.3, 0.4]], device=device)
+            y = torch.tensor([[0], [1], [1]], device=device)
             auc_metric.update([y_pred, y])
 
         result = auc_metric.compute()
-        np.testing.assert_allclose(0.75, result)
+        np.testing.assert_allclose(0.66667, result, rtol=1e-4)
 
 
 if __name__ == "__main__":