4 changes: 2 additions & 2 deletions monai/handlers/classification_saver.py
@@ -88,8 +88,8 @@ def __call__(self, engine: Engine) -> None:
"""
_meta_data = self.batch_transform(engine.state.batch)
if Key.FILENAME_OR_OBJ in _meta_data:
# all gather filenames across ranks
_meta_data[Key.FILENAME_OR_OBJ] = string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ])
# all gather filenames across ranks, only filenames are necessary
_meta_data = {Key.FILENAME_OR_OBJ: string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ])}
# all gather predictions across ranks
_engine_output = evenly_divisible_all_gather(self.output_transform(engine.state.output))

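The change above narrows the gathered meta data to the filenames alone: the handler now builds a fresh dict containing only `Key.FILENAME_OR_OBJ`, so any other meta entries (for example per-image shapes) never reach the CSV writer. Below is a minimal single-process sketch of that narrowing; `fake_string_all_gather` is a hypothetical stand-in for `monai.handlers.utils.string_list_all_gather`, not the real distributed helper.

# Single-process sketch of the meta-dict narrowing done in ClassificationSaver.__call__.
# fake_string_all_gather is a hypothetical stand-in for string_list_all_gather.
from typing import Dict, List

FILENAME_OR_OBJ = "filename_or_obj"


def fake_string_all_gather(strings: List[str]) -> List[str]:
    # a real run would concatenate the filename lists from every rank
    return list(strings)


def narrow_meta(meta: Dict) -> Dict:
    if FILENAME_OR_OBJ in meta:
        # keep only the filenames; other keys (e.g. data_shape) are dropped before saving
        return {FILENAME_OR_OBJ: fake_string_all_gather(meta[FILENAME_OR_OBJ])}
    return meta


meta = {FILENAME_OR_OBJ: ["img0.nii", "img1.nii"], "data_shape": [(1, 1), (1, 1)]}
print(narrow_meta(meta))  # {'filename_or_obj': ['img0.nii', 'img1.nii']}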
2 changes: 1 addition & 1 deletion monai/handlers/utils.py
@@ -103,7 +103,7 @@ def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[str]:
         # all gather across all ranks
         _joined = delimiter.join(idist.all_gather(_joined))
     else:
-        raise RuntimeError("MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0.")
+        raise RuntimeError("string all_gather can not be supported in PyTorch < 1.7.0.")
 
     return _joined.split(delimiter)

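For context, `string_list_all_gather` joins each rank's local list of strings with a delimiter, all-gathers the joined strings (which relies on the string/object `all_gather` support added in PyTorch 1.7.0, hence the updated error message), and then splits the concatenation back into one flat list. Below is a sketch of that join/gather/split pattern, simulating two ranks in a single process; `simulated_gather` is a hypothetical stand-in for `idist.all_gather`.

# Sketch of the join / gather / split pattern behind string_list_all_gather,
# simulated for two "ranks" in one process. simulated_gather is hypothetical.
from typing import List

DELIMITER = "\t"


def simulated_gather(joined_per_rank: List[str]) -> List[str]:
    # stands in for idist.all_gather, which would return one joined string per rank
    return joined_per_rank


rank0_files = ["testfile0", "testfile1"]
rank1_files = ["testfile2"]

# each rank joins its own list into a single delimited string
joined = [DELIMITER.join(rank0_files), DELIMITER.join(rank1_files)]

# gather the per-rank strings, then flatten back into a single list
gathered = DELIMITER.join(simulated_gather(joined))
print(gathered.split(DELIMITER))  # ['testfile0', 'testfile1', 'testfile2']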
7 changes: 6 additions & 1 deletion tests/test_handler_classification_saver_dist.py
@@ -41,7 +41,12 @@ def _train_func(engine, batch):
             saver.attach(engine)
 
             # rank 0 has 8 images, rank 1 has 10 images
-            data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))]}]
+            data = [
+                {
+                    "filename_or_obj": ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))],
+                    "data_shape": [(1, 1) for _ in range(8 * rank, (8 + rank) * (rank + 1))],
+                }
+            ]
             engine.run(data, max_epochs=1)
             filepath = os.path.join(tempdir, "predictions.csv")
             if rank == 1:
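The updated test feeds each rank a different number of images (per the comment in the diff, 8 on rank 0 and 10 on rank 1) plus an extra `data_shape` meta key, exercising both the uneven all-gather and the filename-only narrowing above. A quick check of the range arithmetic used for the filenames (not part of the PR):

# range(8 * rank, (8 + rank) * (rank + 1)) yields 8 filenames on rank 0 and 10 on rank 1
for rank in (0, 1):
    names = ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))]
    print(rank, len(names), names[0], names[-1])
# 0 8 testfile0 testfile7
# 1 10 testfile8 testfile17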