Skip to content
Merged
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
46a7725
Add multi-gpu data analyzer
heyufan1995 Mar 17, 2023
14f1d4e
Merge branch 'Project-MONAI:dev' into dev
heyufan1995 Mar 20, 2023
6f4a0cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2023
f748927
Update multi-gpu data analyzer
heyufan1995 Mar 22, 2023
0b71a76
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 22, 2023
d76feff
Merge branch 'Project-MONAI:dev' into dev
heyufan1995 Mar 23, 2023
0c28435
Merge branch 'Project-MONAI:dev' into dev
heyufan1995 Mar 24, 2023
e556ec2
Ignore non-json/yaml configs for BundleAlg
heyufan1995 Mar 24, 2023
033f14d
Skip multiprocessing init with cpu/single gpu
heyufan1995 Mar 27, 2023
a9689fa
Remove warning filter
heyufan1995 Mar 27, 2023
e2ebaac
Fix label reference error for data analyzer
heyufan1995 Apr 3, 2023
c39e904
delete a tmp file
heyufan1995 Apr 3, 2023
50d5633
Merge branch 'dev' into dev
wyli Apr 5, 2023
c536d44
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 5, 2023
5e58527
[MONAI] code formatting
monai-bot Apr 5, 2023
7a0b73b
Update data_analyzer.py
wyli Apr 5, 2023
8b6c6be
Update data_analyzer.py
wyli Apr 5, 2023
6a4ea86
Merge branch 'dev' into dev
wyli Apr 5, 2023
64148e9
[MONAI] code formatting
monai-bot Apr 5, 2023
41645e1
Update data_analyzer.py
wyli Apr 5, 2023
79b3c6a
Merge branch 'dev' into dev
wyli Apr 5, 2023
d52e50f
change image_only to True for dataanalyzer
heyufan1995 Apr 11, 2023
1918f15
Merge branch 'dev' into dev
wyli Apr 11, 2023
35dc2bd
[MONAI] code formatting
monai-bot Apr 11, 2023
52d0d40
Add stats by case back to datastats.yaml
heyufan1995 Apr 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 116 additions & 51 deletions monai/apps/auto3dseg/data_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

import numpy as np
import torch
from torch.multiprocessing import Manager, Process, set_start_method

from monai.apps.auto3dseg.transforms import EnsureSameShaped
from monai.apps.utils import get_logger
from monai.auto3dseg import SegSummarizer
from monai.auto3dseg.utils import datafold_read
from monai.bundle import config_parser
from monai.bundle.config_parser import ConfigParser
from monai.data import DataLoader, Dataset
from monai.data import DataLoader, Dataset, partition_dataset
from monai.data.utils import no_collation
from monai.transforms import Compose, EnsureTyped, LoadImaged, Orientationd
from monai.utils import StrEnum, min_version, optional_import
Expand Down Expand Up @@ -61,8 +62,7 @@ class DataAnalyzer:
average: whether to average the statistical value across different image modalities.
do_ccp: apply the connected component algorithm to process the labels/images
device: a string specifying hardware (CUDA/CPU) utilized for the operations.
worker: number of workers to use for parallel processing. If device is cuda/GPU, worker has
to be 0.
worker: number of workers to use for loading datasets in each GPU/CPU sub-process.
image_key: a string that user specify for the image. The DataAnalyzer will look it up in the
datalist to locate the image files of the dataset.
label_key: a string that user specify for the label. The DataAnalyzer will look it up in the
Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(
output_path: str = "./datastats.yaml",
average: bool = True,
do_ccp: bool = False,
device: str | torch.device = "cpu",
device: str | torch.device = "cuda",
worker: int = 4,
image_key: str = "image",
label_key: str | None = "label",
Expand Down Expand Up @@ -170,8 +170,9 @@ def _check_data_uniformity(keys: list[str], result: dict) -> bool:

def get_all_case_stats(self, key="training", transform_list=None):
"""
Get all case stats. Caller of the DataAnalyser class. The function iterates datalist and
call get_case_stats to generate stats. Then get_case_summary is called to combine results.
Get all case stats. Caller of the DataAnalyser class. The function initiates multiple GPU or CPU processes of the internal
_get_all_case_stats functions, which iterates datalist and call SegSummarizer to generate stats for each case.
After all case stats are generated, SegSummarizer is called to combine results.

Args:
key: dataset key
Expand All @@ -197,6 +198,84 @@ def get_all_case_stats(self, key="training", transform_list=None):
dictionary will include .nan/.inf in the statistics.

"""
result: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []}
result_bycase: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []}
if self.device.type == "cpu":
nprocs = 1
logger.info("Using CPU for data analyzing!")
else:
nprocs = torch.cuda.device_count()
logger.info(f"Found {nprocs} GPUs for data analyzing!")
if nprocs > 1:
set_start_method("forkserver", force=True)
with Manager() as manager:
manager_list = manager.list()
processes = []
for rank in range(nprocs):
p = Process(target=self._get_all_case_stats, args=(rank, nprocs, manager_list, key, transform_list))
processes.append(p)
for p in processes:
p.start()
for p in processes:
p.join()
# merge DataStatsKeys.BY_CASE
for _ in manager_list:
result_bycase[DataStatsKeys.BY_CASE].extend(_[DataStatsKeys.BY_CASE])
else:
result_bycase = self._get_all_case_stats(0, 1, None, key, transform_list)

summarizer = SegSummarizer(
self.image_key,
self.label_key,
average=self.average,
do_ccp=self.do_ccp,
hist_bins=self.hist_bins,
hist_range=self.hist_range,
histogram_only=self.histogram_only,
)
n_cases = len(result_bycase[DataStatsKeys.BY_CASE])
result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(list, result_bycase[DataStatsKeys.BY_CASE]))
result[DataStatsKeys.SUMMARY]["n_cases"] = n_cases
result[DataStatsKeys.BY_CASE] = [None] * n_cases
result_bycase[DataStatsKeys.SUMMARY] = result[DataStatsKeys.SUMMARY]
if not self._check_data_uniformity([ImageStatsKeys.SPACING], result):
logger.info("Data spacing is not completely uniform. MONAI transforms may provide unexpected result")
if self.output_path:
ConfigParser.export_config_file(
result, self.output_path, fmt=self.fmt, default_flow_style=None, sort_keys=False
)
ConfigParser.export_config_file(
result_bycase,
self.output_path.replace(".yaml", "_by_case.yaml"),
fmt=self.fmt,
default_flow_style=None,
sort_keys=False,
)
# release memory
if self.device.type == "cuda":
# release unreferenced tensors to mitigate OOM
# limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237
torch.cuda.empty_cache()
result[DataStatsKeys.BY_CASE] = result_bycase[DataStatsKeys.BY_CASE]
return result

def _get_all_case_stats(
self,
rank: int = 0,
world_size: int = 1,
manager_list: list | None = None,
key: str = "training",
transform_list: list | None = None,
) -> Any:
"""
Get all case stats from a partitioned datalist. The function can only be called internally by get_all_case_stats.
Args:
rank: GPU process rank, 0 for CPU process
world_size: total number of GPUs, 1 for CPU process
manager_list: multiprocessing manager list object, if using multi-GPU.
key: dataset key
transform_list: option list of transforms before SegSummarizer
"""
summarizer = SegSummarizer(
self.image_key,
self.label_key,
Expand Down Expand Up @@ -224,8 +303,11 @@ def get_all_case_stats(self, key="training", transform_list=None):
)

transform = Compose(transform_list)

files, _ = datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1, key=key)
if world_size <= len(files):
files = partition_dataset(data=files, num_partitions=world_size)[rank]
else:
files = partition_dataset(data=files, num_partitions=len(files))[rank] if rank < len(files) else []
dataset = Dataset(data=files, transform=transform)
dataloader = DataLoader(
dataset,
Expand All @@ -235,22 +317,34 @@ def get_all_case_stats(self, key="training", transform_list=None):
collate_fn=no_collation,
pin_memory=self.device.type == "cuda",
)
result: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []}
result_bycase: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []}

device = self.device if self.device.type == "cpu" else torch.device("cuda", rank)
if not has_tqdm:
warnings.warn("tqdm is not installed. not displaying the caching progress.")

for batch_data in tqdm(dataloader) if has_tqdm else dataloader:
for batch_data in tqdm(dataloader) if (has_tqdm and rank == 0) else dataloader:
batch_data = batch_data[0]
batch_data[self.image_key] = batch_data[self.image_key].to(self.device)

if self.label_key is not None:
label = batch_data[self.label_key]
label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0]
batch_data[self.label_key] = label.to(self.device)

d = summarizer(batch_data)
try:
batch_data[self.image_key] = batch_data[self.image_key].to(device)
if self.label_key is not None:
label = batch_data[self.label_key]
label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0]
batch_data[self.label_key] = label.to(device)
d = summarizer(batch_data)
except BaseException:
if "image_meta_dict" in batch_data.keys():
filename = batch_data["image_meta_dict"]["filename_or_obj"]
else:
filename = batch_data[self.image_key].meta["filename_or_obj"]
logger.info(f"Unable to process data {filename} on {device}.")
if self.device.type == "cuda":
logger.info("DataAnalyzer `device` set to GPU execution hit an exception. Falling back to `cpu`.")
batch_data[self.image_key] = batch_data[self.image_key].to("cpu")
if self.label_key is not None:
label = batch_data[self.label_key]
label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0]
batch_data[self.label_key] = label.to("cpu")
d = summarizer(batch_data)

stats_by_cases = {
DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH],
Expand All @@ -268,36 +362,7 @@ def get_all_case_stats(self, key="training", transform_list=None):
}
)
result_bycase[DataStatsKeys.BY_CASE].append(stats_by_cases)

n_cases = len(result_bycase[DataStatsKeys.BY_CASE])

result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(list, result_bycase[DataStatsKeys.BY_CASE]))
result[DataStatsKeys.SUMMARY]["n_cases"] = n_cases
result[DataStatsKeys.BY_CASE] = [None] * n_cases

if not self._check_data_uniformity([ImageStatsKeys.SPACING], result):
print("Data spacing is not completely uniform. MONAI transforms may provide unexpected result")

if self.output_path:
# saving summary and by_case as 2 files, to minimize loading time when only the summary is necessary
ConfigParser.export_config_file(
result, self.output_path, fmt=self.fmt, default_flow_style=None, sort_keys=False
)
ConfigParser.export_config_file(
result_bycase,
self.output_path.replace(".yaml", "_by_case.yaml"),
fmt=self.fmt,
default_flow_style=None,
sort_keys=False,
)

# release memory
d = None
if self.device.type == "cuda":
# release unreferenced tensors to mitigate OOM
# limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237
torch.cuda.empty_cache()

# return combined
result[DataStatsKeys.BY_CASE] = result_bycase[DataStatsKeys.BY_CASE]
return result
if manager_list is None:
return result_bycase
else:
manager_list.append(result_bycase)