diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 3bb67bdbe2..950baed8f2 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -17,6 +17,7 @@ import numpy as np import torch +from torch.multiprocessing import Manager, Process, set_start_method from monai.apps.auto3dseg.transforms import EnsureSameShaped from monai.apps.utils import get_logger @@ -24,7 +25,7 @@ from monai.auto3dseg.utils import datafold_read from monai.bundle import config_parser from monai.bundle.config_parser import ConfigParser -from monai.data import DataLoader, Dataset +from monai.data import DataLoader, Dataset, partition_dataset from monai.data.utils import no_collation from monai.transforms import Compose, EnsureTyped, LoadImaged, Orientationd from monai.utils import StrEnum, min_version, optional_import @@ -61,8 +62,7 @@ class DataAnalyzer: average: whether to average the statistical value across different image modalities. do_ccp: apply the connected component algorithm to process the labels/images device: a string specifying hardware (CUDA/CPU) utilized for the operations. - worker: number of workers to use for parallel processing. If device is cuda/GPU, worker has - to be 0. + worker: number of workers to use for loading datasets in each GPU/CPU sub-process. image_key: a string that user specify for the image. The DataAnalyzer will look it up in the datalist to locate the image files of the dataset. label_key: a string that user specify for the label. 
The DataAnalyzer will look it up in the @@ -118,7 +118,7 @@ def __init__( output_path: str = "./datastats.yaml", average: bool = True, do_ccp: bool = False, - device: str | torch.device = "cpu", + device: str | torch.device = "cuda", worker: int = 4, image_key: str = "image", label_key: str | None = "label", @@ -170,8 +170,9 @@ def _check_data_uniformity(keys: list[str], result: dict) -> bool: def get_all_case_stats(self, key="training", transform_list=None): """ - Get all case stats. Caller of the DataAnalyser class. The function iterates datalist and - call get_case_stats to generate stats. Then get_case_summary is called to combine results. + Get all case stats. Caller of the DataAnalyzer class. The function launches multiple GPU or CPU processes running the internal + _get_all_case_stats function, which iterates over the datalist and calls SegSummarizer to generate stats for each case. + After all case stats are generated, SegSummarizer is called to combine the results. Args: key: dataset key @@ -197,6 +198,84 @@ def get_all_case_stats(self, key="training", transform_list=None): dictionary will include .nan/.inf in the statistics. 
""" + result: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} + result_bycase: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} + if self.device.type == "cpu": + nprocs = 1 + logger.info("Using CPU for data analyzing!") + else: + nprocs = torch.cuda.device_count() + logger.info(f"Found {nprocs} GPUs for data analyzing!") + if nprocs > 1: + set_start_method("forkserver", force=True) + with Manager() as manager: + manager_list = manager.list() + processes = [] + for rank in range(nprocs): + p = Process(target=self._get_all_case_stats, args=(rank, nprocs, manager_list, key, transform_list)) + processes.append(p) + for p in processes: + p.start() + for p in processes: + p.join() + # merge DataStatsKeys.BY_CASE + for _ in manager_list: + result_bycase[DataStatsKeys.BY_CASE].extend(_[DataStatsKeys.BY_CASE]) + else: + result_bycase = self._get_all_case_stats(0, 1, None, key, transform_list) + + summarizer = SegSummarizer( + self.image_key, + self.label_key, + average=self.average, + do_ccp=self.do_ccp, + hist_bins=self.hist_bins, + hist_range=self.hist_range, + histogram_only=self.histogram_only, + ) + n_cases = len(result_bycase[DataStatsKeys.BY_CASE]) + result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(list, result_bycase[DataStatsKeys.BY_CASE])) + result[DataStatsKeys.SUMMARY]["n_cases"] = n_cases + result[DataStatsKeys.BY_CASE] = [None] * n_cases + result_bycase[DataStatsKeys.SUMMARY] = result[DataStatsKeys.SUMMARY] + if not self._check_data_uniformity([ImageStatsKeys.SPACING], result): + logger.info("Data spacing is not completely uniform. 
MONAI transforms may provide unexpected result") + if self.output_path: + ConfigParser.export_config_file( + result, self.output_path, fmt=self.fmt, default_flow_style=None, sort_keys=False + ) + ConfigParser.export_config_file( + result_bycase, + self.output_path.replace(".yaml", "_by_case.yaml"), + fmt=self.fmt, + default_flow_style=None, + sort_keys=False, + ) + # release memory + if self.device.type == "cuda": + # release unreferenced tensors to mitigate OOM + # limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237 + torch.cuda.empty_cache() + result[DataStatsKeys.BY_CASE] = result_bycase[DataStatsKeys.BY_CASE] + return result + + def _get_all_case_stats( + self, + rank: int = 0, + world_size: int = 1, + manager_list: list | None = None, + key: str = "training", + transform_list: list | None = None, + ) -> Any: + """ + Get all case stats from a partitioned datalist. The function can only be called internally by get_all_case_stats. + Args: + rank: GPU process rank, 0 for CPU process + world_size: total number of GPUs, 1 for CPU process + manager_list: multiprocessing manager list object, if using multi-GPU. 
+ key: dataset key + transform_list: optional list of transforms applied before SegSummarizer + """ summarizer = SegSummarizer( self.image_key, self.label_key, @@ -224,8 +303,11 @@ def get_all_case_stats(self, key="training", transform_list=None): ) transform = Compose(transform_list) - files, _ = datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1, key=key) + if world_size <= len(files): + files = partition_dataset(data=files, num_partitions=world_size)[rank] + else: + files = partition_dataset(data=files, num_partitions=len(files))[rank] if rank < len(files) else [] dataset = Dataset(data=files, transform=transform) dataloader = DataLoader( dataset, @@ -235,22 +317,34 @@ def get_all_case_stats(self, key="training", transform_list=None): collate_fn=no_collation, pin_memory=self.device.type == "cuda", ) - result: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} result_bycase: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} - + device = self.device if self.device.type == "cpu" else torch.device("cuda", rank) if not has_tqdm: warnings.warn("tqdm is not installed. 
not displaying the caching progress.") - for batch_data in tqdm(dataloader) if has_tqdm else dataloader: + for batch_data in tqdm(dataloader) if (has_tqdm and rank == 0) else dataloader: batch_data = batch_data[0] - batch_data[self.image_key] = batch_data[self.image_key].to(self.device) - - if self.label_key is not None: - label = batch_data[self.label_key] - label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] - batch_data[self.label_key] = label.to(self.device) - - d = summarizer(batch_data) + try: + batch_data[self.image_key] = batch_data[self.image_key].to(device) + if self.label_key is not None: + label = batch_data[self.label_key] + label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] + batch_data[self.label_key] = label.to(device) + d = summarizer(batch_data) + except BaseException: + if "image_meta_dict" in batch_data.keys(): + filename = batch_data["image_meta_dict"]["filename_or_obj"] + else: + filename = batch_data[self.image_key].meta["filename_or_obj"] + logger.info(f"Unable to process data {filename} on {device}.") + if self.device.type == "cuda": + logger.info("DataAnalyzer `device` set to GPU execution hit an exception. 
Falling back to `cpu`.") + batch_data[self.image_key] = batch_data[self.image_key].to("cpu") + if self.label_key is not None: + label = batch_data[self.label_key] + label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] + batch_data[self.label_key] = label.to("cpu") + d = summarizer(batch_data) stats_by_cases = { DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH], @@ -268,36 +362,7 @@ def get_all_case_stats(self, key="training", transform_list=None): } ) result_bycase[DataStatsKeys.BY_CASE].append(stats_by_cases) - - n_cases = len(result_bycase[DataStatsKeys.BY_CASE]) - - result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(list, result_bycase[DataStatsKeys.BY_CASE])) - result[DataStatsKeys.SUMMARY]["n_cases"] = n_cases - result[DataStatsKeys.BY_CASE] = [None] * n_cases - - if not self._check_data_uniformity([ImageStatsKeys.SPACING], result): - print("Data spacing is not completely uniform. MONAI transforms may provide unexpected result") - - if self.output_path: - # saving summary and by_case as 2 files, to minimize loading time when only the summary is necessary - ConfigParser.export_config_file( - result, self.output_path, fmt=self.fmt, default_flow_style=None, sort_keys=False - ) - ConfigParser.export_config_file( - result_bycase, - self.output_path.replace(".yaml", "_by_case.yaml"), - fmt=self.fmt, - default_flow_style=None, - sort_keys=False, - ) - - # release memory - d = None - if self.device.type == "cuda": - # release unreferenced tensors to mitigate OOM - # limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237 - torch.cuda.empty_cache() - - # return combined - result[DataStatsKeys.BY_CASE] = result_bycase[DataStatsKeys.BY_CASE] - return result + if manager_list is None: + return result_bycase + else: + manager_list.append(result_bycase)