4 changes: 2 additions & 2 deletions monai/data/dataset.py
@@ -572,8 +572,8 @@ class SmartCacheDataset(CacheDataset):
    4. Call `shutdown()` when training ends.

    Note:
-        This replacement will not work if set the `multiprocessing_context` of DataLoader to `spawn`
-        or on windows(the default multiprocessing method is `spawn`) and set `num_workers` greater than 0 .
+        This replacement will not work if setting the `multiprocessing_context` of DataLoader to `spawn`
+        or on Windows (the default multiprocessing method is `spawn`) and setting `num_workers` greater than 0.

    """

27 changes: 26 additions & 1 deletion monai/handlers/roc_auc.py
@@ -9,13 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Callable, Optional, Union
+from typing import Any, Callable, Optional, Union

import torch

+from monai.handlers.utils import evenly_divisible_all_gather
from monai.metrics import compute_roc_auc
from monai.utils import Average, exact_version, optional_import

+idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
EpochMetric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "EpochMetric")
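(Aside for readers new to MONAI's lazy-import helper: `optional_import` defers the hard ignite dependency and returns the resolved attribute plus a success flag. A small sketch of how the two imports above behave when ignite 0.4.2 is installed:)

```python
# Sketch: optional_import resolves "ignite.distributed" lazily and reports
# whether the dependency is available at the pinned version.
from monai.utils import exact_version, optional_import

idist, has_ignite = optional_import("ignite", "0.4.2", exact_version, "distributed")
if has_ignite:
    print(idist.get_world_size())  # 1 when not running under a distributed launcher
```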


@@ -71,9 +73,32 @@ def _compute_fn(pred, label):
                average=Average(average),
            )

+        self._is_reduced: bool = False
        super().__init__(
            compute_fn=_compute_fn,
            output_transform=output_transform,
            check_compute_fn=False,
            device=device,
        )

+    def compute(self) -> Any:
+        _prediction_tensor = torch.cat(self._predictions, dim=0)
+        _target_tensor = torch.cat(self._targets, dim=0)
+
+        ws = idist.get_world_size()
+        if ws > 1 and not self._is_reduced:
+            # All gather across all processes
+            _prediction_tensor = evenly_divisible_all_gather(_prediction_tensor)
+            _target_tensor = evenly_divisible_all_gather(_target_tensor)
+        self._is_reduced = True
+
+        result: torch.Tensor = torch.zeros(1)
+        if idist.get_rank() == 0:
+            # Run compute_fn on zero rank only
+            result = self.compute_fn(_prediction_tensor, _target_tensor)
+
+        if ws > 1:
+            # broadcast result to all processes
+            result = idist.broadcast(result, src=0)
+
+        return result.item() if torch.is_tensor(result) else result
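This `compute` mirrors ignite's `EpochMetric`: gather every rank's shard once, score on rank 0, then broadcast the scalar. The same flow can be restated with plain `torch.distributed` (a sketch assuming an initialized process group and equal shard sizes per rank; `metric_fn` is a stand-in for the actual compute function, and `evenly_divisible_all_gather` exists precisely to lift the equal-size restriction):

```python
# Sketch of the gather -> rank-0 compute -> broadcast pattern with plain
# torch.distributed; assumes dist.init_process_group() was already called.
import torch
import torch.distributed as dist


def distributed_metric(pred: torch.Tensor, target: torch.Tensor, metric_fn) -> float:
    ws = dist.get_world_size() if dist.is_initialized() else 1
    if ws > 1:
        # gather every rank's shard of predictions/targets (equal sizes assumed)
        preds = [torch.empty_like(pred) for _ in range(ws)]
        targets = [torch.empty_like(target) for _ in range(ws)]
        dist.all_gather(preds, pred)
        dist.all_gather(targets, target)
        pred, target = torch.cat(preds), torch.cat(targets)

    result = torch.zeros(1)
    if ws == 1 or dist.get_rank() == 0:
        # score on rank 0 only, then share the scalar with every rank
        result = torch.tensor([float(metric_fn(pred, target))])
    if ws > 1:
        dist.broadcast(result, src=0)
    return result.item()
```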
3 changes: 3 additions & 0 deletions monai/handlers/utils.py
@@ -62,6 +62,9 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor:
    Args:
        data: source tensor to pad and execute all_gather in distributed data parallel.

+    Note:
+        The input data on different ranks must have exactly the same `dtype`.

    """
    if not isinstance(data, torch.Tensor):
        raise ValueError("input data must be PyTorch Tensor.")
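For context on why the new note matters: the function pads each rank's tensor to a common length before `all_gather`, then strips the padding, and the gathered shards are concatenated into one tensor, so the shards must agree on `dtype`. A rough sketch of that idea (not the exact MONAI implementation; assumes an initialized process group):

```python
# Rough sketch of the pad-then-gather idea (not the exact MONAI code);
# assumes an initialized torch.distributed process group.
import torch
import torch.distributed as dist


def pad_and_all_gather(data: torch.Tensor) -> torch.Tensor:
    ws = dist.get_world_size()
    length = torch.tensor([data.shape[0]], device=data.device)
    all_lens = [torch.zeros_like(length) for _ in range(ws)]
    dist.all_gather(all_lens, length)  # share every rank's row count
    max_len = int(max(n.item() for n in all_lens))
    if data.shape[0] < max_len:
        # pad with copies of the first row so every rank sends the same shape
        pad = data[:1].repeat(max_len - data.shape[0], *([1] * (data.dim() - 1)))
        data = torch.cat([data, pad], dim=0)
    gathered = [torch.empty_like(data) for _ in range(ws)]
    dist.all_gather(gathered, data)  # this step requires matching dtypes
    # drop the padding using the true per-rank lengths, then concatenate
    return torch.cat([g[: int(n.item())] for g, n in zip(gathered, all_lens)], dim=0)
```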
6 changes: 3 additions & 3 deletions tests/test_handler_rocauc_dist.py
@@ -31,12 +31,12 @@ def test_compute(self):
            auc_metric.update([y_pred, y])

        if dist.get_rank() == 1:
-            y_pred = torch.tensor([[0.2, 0.1], [0.1, 0.5]], device=device)
-            y = torch.tensor([[0], [1]], device=device)
+            y_pred = torch.tensor([[0.2, 0.1], [0.1, 0.5], [0.3, 0.4]], device=device)
+            y = torch.tensor([[0], [1], [1]], device=device)
            auc_metric.update([y_pred, y])

        result = auc_metric.compute()
-        np.testing.assert_allclose(0.75, result)
+        np.testing.assert_allclose(0.66667, result, rtol=1e-4)


if __name__ == "__main__":
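The updated expectation only materializes when the test runs with two ranks. One way to drive a two-rank check like this on a single machine (a sketch; the repository's own launcher for this test may differ):

```python
# Sketch: spawn two local processes with a gloo process group so that
# dist.get_rank() takes the values 0 and 1 exercised by test_compute.
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # ... build the ROCAUC handler, update() with this rank's shard,
    # and assert on compute(), as in test_compute above ...
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```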