From 46a77253ef2b68e12ef089ef9e102b10a973153a Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Fri, 17 Mar 2023 11:26:31 -0400 Subject: [PATCH 1/8] Add multi-gpu data analyzer --- monai/apps/auto3dseg/data_analyzer.py | 116 +++++++++++++++----------- 1 file changed, 67 insertions(+), 49 deletions(-) diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 3bb67bdbe2..c68b0f620d 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -17,6 +17,7 @@ import numpy as np import torch +from torch.multiprocessing import Process, set_start_method, Manager from monai.apps.auto3dseg.transforms import EnsureSameShaped from monai.apps.utils import get_logger @@ -24,12 +25,15 @@ from monai.auto3dseg.utils import datafold_read from monai.bundle import config_parser from monai.bundle.config_parser import ConfigParser -from monai.data import DataLoader, Dataset +from monai.data import DataLoader, Dataset, partition_dataset from monai.data.utils import no_collation from monai.transforms import Compose, EnsureTyped, LoadImaged, Orientationd from monai.utils import StrEnum, min_version, optional_import from monai.utils.enums import DataStatsKeys, ImageStatsKeys +import warnings +# remove the warning "warnings.warn(f"Modifying image pixdim from {pixdim} to {norm}")" +warnings.filterwarnings("ignore", category=UserWarning, module='monai') def strenum_representer(dumper, data): return dumper.represent_scalar("tag:yaml.org,2002:str", data.value) @@ -115,10 +119,10 @@ def __init__( self, datalist: str | dict, dataroot: str = "", - output_path: str = "./datastats.yaml", + output_path: str = "./data_stats.yaml", average: bool = True, do_ccp: bool = False, - device: str | torch.device = "cpu", + device: str | torch.device = "cuda", worker: int = 4, image_key: str = "image", label_key: str | None = "label", @@ -169,6 +173,58 @@ def _check_data_uniformity(keys: list[str], result: dict) -> bool: return True def get_all_case_stats(self, key="training", transform_list=None): + """ Wrapper for the internal _get_all_case_stats to perform multi-gpu processing + """ + if self.device.type == 'cpu': + nprocs = 1 + print(f'Using CPU for data analyzing!') + else: + nprocs = torch.cuda.device_count() + print(f'Found {nprocs} GPUs for data analyzing!') + set_start_method('forkserver', force=True) + with Manager() as manager: + manager_list = manager.list() + processes = [] + for rank in range(nprocs): + p = Process(target=self._get_all_case_stats, args=(rank, nprocs, manager_list, key, transform_list)) + processes.append(p) + for p in processes: + p.start() + for p in processes: + p.join() + # merge DataStatsKeys.BY_CASE + result: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} + for _ in manager_list: + result[DataStatsKeys.BY_CASE].extend(_[DataStatsKeys.BY_CASE]) + summarizer = SegSummarizer( + self.image_key, + self.label_key, + average=self.average, + do_ccp=self.do_ccp, + hist_bins=self.hist_bins, + hist_range=self.hist_range, + histogram_only=self.histogram_only, + ) + result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(list, result[DataStatsKeys.BY_CASE])) + + if not self._check_data_uniformity([ImageStatsKeys.SPACING], result): + print("Data spacing is not completely uniform. MONAI transforms may provide unexpected result") + + if self.output_path: + ConfigParser.export_config_file( + result, self.output_path, fmt=self.fmt, default_flow_style=None, sort_keys=False + ) + + # release memory + d = None + if self.device.type == "cuda": + # release unreferenced tensors to mitigate OOM + # limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237 + torch.cuda.empty_cache() + + return result + + def _get_all_case_stats(self, rank: int=0, world_size: int=1, manager_list: Manager.list=[], key="training", transform_list=None): """ Get all case stats. Caller of the DataAnalyser class. The function iterates datalist and call get_case_stats to generate stats. Then get_case_summary is called to combine results. @@ -224,31 +280,23 @@ def get_all_case_stats(self, key="training", transform_list=None): ) transform = Compose(transform_list) - files, _ = datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1, key=key) + files = partition_dataset(data=files, num_partitions=world_size)[rank] dataset = Dataset(data=files, transform=transform) - dataloader = DataLoader( - dataset, - batch_size=1, - shuffle=False, - num_workers=self.worker, - collate_fn=no_collation, - pin_memory=self.device.type == "cuda", - ) + dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.worker, collate_fn=no_collation) result: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} - result_bycase: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} - + device = self.device if self.device.type == 'cpu' else torch.device(f'cuda',rank) if not has_tqdm: warnings.warn("tqdm is not installed. not displaying the caching progress.") - for batch_data in tqdm(dataloader) if has_tqdm else dataloader: + for batch_data in tqdm(dataloader) if (has_tqdm and rank==0) else dataloader: batch_data = batch_data[0] - batch_data[self.image_key] = batch_data[self.image_key].to(self.device) + batch_data[self.image_key] = batch_data[self.image_key].to(device) if self.label_key is not None: label = batch_data[self.label_key] label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] - batch_data[self.label_key] = label.to(self.device) + batch_data[self.label_key] = label.to(device) d = summarizer(batch_data) @@ -267,37 +315,7 @@ def get_all_case_stats(self, key="training", transform_list=None): DataStatsKeys.LABEL_STATS: d[DataStatsKeys.LABEL_STATS], } ) - result_bycase[DataStatsKeys.BY_CASE].append(stats_by_cases) - - n_cases = len(result_bycase[DataStatsKeys.BY_CASE]) - - result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(list, result_bycase[DataStatsKeys.BY_CASE])) - result[DataStatsKeys.SUMMARY]["n_cases"] = n_cases - result[DataStatsKeys.BY_CASE] = [None] * n_cases - - if not self._check_data_uniformity([ImageStatsKeys.SPACING], result): - print("Data spacing is not completely uniform. MONAI transforms may provide unexpected result") - - if self.output_path: - # saving summary and by_case as 2 files, to minimize loading time when only the summary is necessary - ConfigParser.export_config_file( - result, self.output_path, fmt=self.fmt, default_flow_style=None, sort_keys=False - ) - ConfigParser.export_config_file( - result_bycase, - self.output_path.replace(".yaml", "_by_case.yaml"), - fmt=self.fmt, - default_flow_style=None, - sort_keys=False, - ) + result[DataStatsKeys.BY_CASE].append(stats_by_cases) + manager_list.append(result) - # release memory - d = None - if self.device.type == "cuda": - # release unreferenced tensors to mitigate OOM - # limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237 - torch.cuda.empty_cache() - # return combined - result[DataStatsKeys.BY_CASE] = result_bycase[DataStatsKeys.BY_CASE] - return result From 6f4a0ccd7d1d0be7712fb6595fc7f176ccceab49 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Mar 2023 17:58:19 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/auto3dseg/data_analyzer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index c68b0f620d..1729face53 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -181,7 +181,7 @@ def get_all_case_stats(self, key="training", transform_list=None): else: nprocs = torch.cuda.device_count() print(f'Found {nprocs} GPUs for data analyzing!') - set_start_method('forkserver', force=True) + set_start_method('forkserver', force=True) with Manager() as manager: manager_list = manager.list() processes = [] @@ -317,5 +317,3 @@ def _get_all_case_stats(self, rank: int=0, world_size: int=1, manager_list: Mana ) result[DataStatsKeys.BY_CASE].append(stats_by_cases) manager_list.append(result) - - From f7489270dd53df570b0fb4c2a2000a4e673462e7 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Tue, 21 Mar 2023 22:11:26 -0400 Subject: [PATCH 3/8] Update multi-gpu data analyzer Signed-off-by: heyufan1995 --- monai/apps/auto3dseg/data_analyzer.py | 131 ++++++++++++++++---------- 1 file changed, 81 insertions(+), 50 deletions(-) diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 1729face53..cb1d2acf6a 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -65,8 +65,7 @@ class DataAnalyzer: average: whether to average the statistical value across different image modalities. do_ccp: apply the connected component algorithm to process the labels/images device: a string specifying hardware (CUDA/CPU) utilized for the operations. - worker: number of workers to use for parallel processing. If device is cuda/GPU, worker has - to be 0. + worker: number of workers to use for loading datasets in each GPU/CPU sub-process. image_key: a string that user specify for the image. The DataAnalyzer will look it up in the datalist to locate the image files of the dataset. label_key: a string that user specify for the label. The DataAnalyzer will look it up in the @@ -119,7 +118,7 @@ def __init__( self, datalist: str | dict, dataroot: str = "", - output_path: str = "./data_stats.yaml", + output_path: str = "./datastats.yaml", average: bool = True, do_ccp: bool = False, device: str | torch.device = "cuda", @@ -173,14 +172,41 @@ def _check_data_uniformity(keys: list[str], result: dict) -> bool: return True def get_all_case_stats(self, key="training", transform_list=None): - """ Wrapper for the internal _get_all_case_stats to perform multi-gpu processing + """ + Get all case stats. Caller of the DataAnalyser class. The function initiates multiple GPU or CPU processes of the internal + _get_all_case_stats functions, which iterates datalist and call SegSummarizer to generate stats for each case. + After all case stats are generated, SegSummarizer is called to combine results. + + Args: + key: dataset key + transform_list: option list of transforms before SegSummarizer + + Returns: + A data statistics dictionary containing + "stats_summary" (summary statistics of the entire datasets). Within stats_summary + there are "image_stats" (summarizing info of shape, channel, spacing, and etc + using operations_summary), "image_foreground_stats" (info of the intensity for the + non-zero labeled voxels), and "label_stats" (info of the labels, pixel percentage, + image_intensity, and each individual label in a list) + "stats_by_cases" (List type value. Each element of the list is statistics of + an image-label info. Within each element, there are: "image" (value is the + path to an image), "label" (value is the path to the corresponding label), "image_stats" + (summarizing info of shape, channel, spacing, and etc using operations), + "image_foreground_stats" (similar to the previous one but one foreground image), and + "label_stats" (stats of the individual labels ) + + Notes: + Since the backend of the statistics computation are torch/numpy, nan/inf value + may be generated and carried over in the computation. In such cases, the output + dictionary will include .nan/.inf in the statistics. + """ if self.device.type == 'cpu': nprocs = 1 - print(f'Using CPU for data analyzing!') + logger.info(f'Using CPU for data analyzing!') else: nprocs = torch.cuda.device_count() - print(f'Found {nprocs} GPUs for data analyzing!') + logger.info(f'Found {nprocs} GPUs for data analyzing!') set_start_method('forkserver', force=True) with Manager() as manager: manager_list = manager.list() @@ -194,8 +220,9 @@ def get_all_case_stats(self, key="training", transform_list=None): p.join() # merge DataStatsKeys.BY_CASE result: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} + result_bycase: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} for _ in manager_list: - result[DataStatsKeys.BY_CASE].extend(_[DataStatsKeys.BY_CASE]) + result_bycase[DataStatsKeys.BY_CASE].extend(_[DataStatsKeys.BY_CASE]) summarizer = SegSummarizer( self.image_key, self.label_key, @@ -205,53 +232,39 @@ def get_all_case_stats(self, key="training", transform_list=None): hist_range=self.hist_range, histogram_only=self.histogram_only, ) - result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(list, result[DataStatsKeys.BY_CASE])) - + n_cases = len(result_bycase[DataStatsKeys.BY_CASE]) + result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(list, result_bycase[DataStatsKeys.BY_CASE])) + result[DataStatsKeys.SUMMARY]["n_cases"] = n_cases + result_bycase[DataStatsKeys.SUMMARY] = result[DataStatsKeys.SUMMARY] if not self._check_data_uniformity([ImageStatsKeys.SPACING], result): - print("Data spacing is not completely uniform. MONAI transforms may provide unexpected result") + logger.info("Data spacing is not completely uniform. MONAI transforms may provide unexpected result") if self.output_path: ConfigParser.export_config_file( result, self.output_path, fmt=self.fmt, default_flow_style=None, sort_keys=False ) - + ConfigParser.export_config_file( + result_bycase, + self.output_path.replace(".yaml", "_by_case.yaml"), + fmt=self.fmt, + default_flow_style=None, + sort_keys=False, + ) # release memory d = None if self.device.type == "cuda": # release unreferenced tensors to mitigate OOM # limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237 torch.cuda.empty_cache() - + result[DataStatsKeys.BY_CASE] = result_bycase[DataStatsKeys.BY_CASE] return result def _get_all_case_stats(self, rank: int=0, world_size: int=1, manager_list: Manager.list=[], key="training", transform_list=None): """ - Get all case stats. Caller of the DataAnalyser class. The function iterates datalist and - call get_case_stats to generate stats. Then get_case_summary is called to combine results. - + Get all case stats from a partitioned datalist. The function can only be called internally by get_all_case_stats. Args: key: dataset key transform_list: option list of transforms before SegSummarizer - - Returns: - A data statistics dictionary containing - "stats_summary" (summary statistics of the entire datasets). Within stats_summary - there are "image_stats" (summarizing info of shape, channel, spacing, and etc - using operations_summary), "image_foreground_stats" (info of the intensity for the - non-zero labeled voxels), and "label_stats" (info of the labels, pixel percentage, - image_intensity, and each individual label in a list) - "stats_by_cases" (List type value. Each element of the list is statistics of - an image-label info. Within each element, there are: "image" (value is the - path to an image), "label" (value is the path to the corresponding label), "image_stats" - (summarizing info of shape, channel, spacing, and etc using operations), - "image_foreground_stats" (similar to the previous one but one foreground image), and - "label_stats" (stats of the individual labels ) - - Notes: - Since the backend of the statistics computation are torch/numpy, nan/inf value - may be generated and carried over in the computation. In such cases, the output - dictionary will include .nan/.inf in the statistics. - """ summarizer = SegSummarizer( self.image_key, @@ -265,7 +278,7 @@ def _get_all_case_stats(self, rank: int=0, world_size: int=1, manager_list: Mana keys = list(filter(None, [self.image_key, self.label_key])) if transform_list is None: transform_list = [ - LoadImaged(keys=keys, ensure_channel_first=True, image_only=True), + LoadImaged(keys=keys, ensure_channel_first=True, image_only=False), EnsureTyped(keys=keys, data_type="tensor", dtype=torch.float), Orientationd(keys=keys, axcodes="RAS"), ] @@ -281,25 +294,43 @@ def _get_all_case_stats(self, rank: int=0, world_size: int=1, manager_list: Mana transform = Compose(transform_list) files, _ = datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1, key=key) - files = partition_dataset(data=files, num_partitions=world_size)[rank] + if world_size <= len(files): + files = partition_dataset(data=files, num_partitions=world_size)[rank] + else: + files = partition_dataset(data=files, num_partitions=len(files))[rank] if rank < len(files) else [] dataset = Dataset(data=files, transform=transform) - dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.worker, collate_fn=no_collation) - result: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} + dataloader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=self.worker, + collate_fn=no_collation, + pin_memory=self.device.type == "cuda", + ) + result_bycase: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} device = self.device if self.device.type == 'cpu' else torch.device(f'cuda',rank) if not has_tqdm: warnings.warn("tqdm is not installed. not displaying the caching progress.") for batch_data in tqdm(dataloader) if (has_tqdm and rank==0) else dataloader: batch_data = batch_data[0] - batch_data[self.image_key] = batch_data[self.image_key].to(device) - - if self.label_key is not None: - label = batch_data[self.label_key] - label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] - batch_data[self.label_key] = label.to(device) - - d = summarizer(batch_data) - + try: + batch_data[self.image_key] = batch_data[self.image_key].to(device) + if self.label_key is not None: + label = batch_data[self.label_key] + label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] + batch_data[self.label_key] = label.to(device) + + d = summarizer(batch_data) + except: + logger.info(f"Unable to process data {batch_data['image_meta_dict']['filename_or_obj']} on {device}.") + if self.device.type == 'cuda': + logger.info(f"Data analysis using CPU.") + batch_data[self.image_key] = batch_data[self.image_key].to('cpu') + if self.label_key is not None: + batch_data[self.label_key] = label.to('cpu') + d = summarizer(batch_data) + stats_by_cases = { DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH], DataStatsKeys.BY_CASE_LABEL_PATH: d[DataStatsKeys.BY_CASE_LABEL_PATH], @@ -315,5 +346,5 @@ def _get_all_case_stats(self, rank: int=0, world_size: int=1, manager_list: Mana DataStatsKeys.LABEL_STATS: d[DataStatsKeys.LABEL_STATS], } ) - result[DataStatsKeys.BY_CASE].append(stats_by_cases) - manager_list.append(result) + result_bycase[DataStatsKeys.BY_CASE].append(stats_by_cases) + manager_list.append(result_bycase) From 3cacadd8f6962440fbc08173d2b4d9269791ceb2 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Mon, 20 Mar 2023 16:32:44 -0400 Subject: [PATCH 4/8] Enable swinuentr v2 --- monai/networks/nets/swin_unetr.py | 38 +++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 9f8204968f..5ea5fd57a2 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -65,6 +65,7 @@ def __init__( use_checkpoint: bool = False, spatial_dims: int = 3, downsample="merging", + use_v2=True ) -> None: """ Args: @@ -84,6 +85,7 @@ def __init__( downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`. The default is currently `"merging"` (the original version defined in v0.9.0). + use_v2: using swinunetr_v2, which adds a residual convolution block at the beggining of each swin stage. Examples:: @@ -142,6 +144,7 @@ def __init__( use_checkpoint=use_checkpoint, spatial_dims=spatial_dims, downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample, + use_v2=use_v2 ) self.encoder1 = UnetrBasicBlock( @@ -921,6 +924,7 @@ def __init__( use_checkpoint: bool = False, spatial_dims: int = 3, downsample="merging", + use_v2=False ) -> None: """ Args: @@ -942,6 +946,7 @@ def __init__( downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`. The default is currently `"merging"` (the original version defined in v0.9.0). + use_v2: using swinunetr_v2, which adds a residual convolution block at the beggining of each swin stage. """ super().__init__() @@ -959,10 +964,16 @@ def __init__( ) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + self.use_v2 = use_v2 self.layers1 = nn.ModuleList() self.layers2 = nn.ModuleList() self.layers3 = nn.ModuleList() self.layers4 = nn.ModuleList() + if self.use_v2: + self.layers1c = nn.ModuleList() + self.layers2c = nn.ModuleList() + self.layers3c = nn.ModuleList() + self.layers4c = nn.ModuleList() down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample for i_layer in range(self.num_layers): layer = BasicLayer( @@ -987,6 +998,25 @@ def __init__( self.layers3.append(layer) elif i_layer == 3: self.layers4.append(layer) + if self.use_v2: + layerc = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim * 2**i_layer, + out_channels=embed_dim * 2**i_layer, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + if i_layer == 0: + self.layers1c.append(layerc) + elif i_layer == 1: + self.layers2c.append(layerc) + elif i_layer == 2: + self.layers3c.append(layerc) + elif i_layer == 3: + self.layers4c.append(layerc) + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) def proj_out(self, x, normalize=False): @@ -1008,12 +1038,20 @@ def forward(self, x, normalize=True): x0 = self.patch_embed(x) x0 = self.pos_drop(x0) x0_out = self.proj_out(x0, normalize) + if self.use_v2: + x0 = self.layers1c[0](x0.contiguous()) x1 = self.layers1[0](x0.contiguous()) x1_out = self.proj_out(x1, normalize) + if self.use_v2: + x1 = self.layers2c[0](x1.contiguous()) x2 = self.layers2[0](x1.contiguous()) x2_out = self.proj_out(x2, normalize) + if self.use_v2: + x2 = self.layers3c[0](x2.contiguous()) x3 = self.layers3[0](x2.contiguous()) x3_out = self.proj_out(x3, normalize) + if self.use_v2: + x3 = self.layers4c[0](x3.contiguous()) x4 = self.layers4[0](x3.contiguous()) x4_out = self.proj_out(x4, normalize) return [x0_out, x1_out, x2_out, x3_out, x4_out] From f4b11bcffade2ab9074b8a5876f2ca607e2639fe Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Mon, 20 Mar 2023 19:08:18 -0400 Subject: [PATCH 5/8] Change use_v2 to false by default --- monai/networks/nets/swin_unetr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 5ea5fd57a2..58d0e58c83 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -65,7 +65,7 @@ def __init__( use_checkpoint: bool = False, spatial_dims: int = 3, downsample="merging", - use_v2=True + use_v2=False ) -> None: """ Args: From 042b0fc8c20d190afe78672f8aeed23ced9fafdb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Mar 2023 14:41:46 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/auto3dseg/data_analyzer.py | 6 +++--- monai/networks/nets/swin_unetr.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index cb1d2acf6a..da156998c3 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -173,8 +173,8 @@ def _check_data_uniformity(keys: list[str], result: dict) -> bool: def get_all_case_stats(self, key="training", transform_list=None): """ - Get all case stats. Caller of the DataAnalyser class. The function initiates multiple GPU or CPU processes of the internal - _get_all_case_stats functions, which iterates datalist and call SegSummarizer to generate stats for each case. + Get all case stats. Caller of the DataAnalyser class. The function initiates multiple GPU or CPU processes of the internal + _get_all_case_stats functions, which iterates datalist and call SegSummarizer to generate stats for each case. After all case stats are generated, SegSummarizer is called to combine results. Args: @@ -330,7 +330,7 @@ def _get_all_case_stats(self, rank: int=0, world_size: int=1, manager_list: Mana if self.label_key is not None: batch_data[self.label_key] = label.to('cpu') d = summarizer(batch_data) - + stats_by_cases = { DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH], DataStatsKeys.BY_CASE_LABEL_PATH: d[DataStatsKeys.BY_CASE_LABEL_PATH], diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 58d0e58c83..c1d4331a74 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -973,7 +973,7 @@ def __init__( self.layers1c = nn.ModuleList() self.layers2c = nn.ModuleList() self.layers3c = nn.ModuleList() - self.layers4c = nn.ModuleList() + self.layers4c = nn.ModuleList() down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample for i_layer in range(self.num_layers): layer = BasicLayer( @@ -1015,7 +1015,7 @@ def __init__( elif i_layer == 2: self.layers3c.append(layerc) elif i_layer == 3: - self.layers4c.append(layerc) + self.layers4c.append(layerc) self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) From dd4bf6202ee50b7f3e23b0799479eba31910cfbd Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 22 Mar 2023 14:43:43 +0000 Subject: [PATCH 7/8] rebase Signed-off-by: Wenqi Li --- monai/apps/auto3dseg/data_analyzer.py | 151 +++++++++----------------- 1 file changed, 52 insertions(+), 99 deletions(-) diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index da156998c3..3bb67bdbe2 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -17,7 +17,6 @@ import numpy as np import torch -from torch.multiprocessing import Process, set_start_method, Manager from monai.apps.auto3dseg.transforms import EnsureSameShaped from monai.apps.utils import get_logger @@ -25,15 +24,12 @@ from monai.auto3dseg.utils import datafold_read from monai.bundle import config_parser from monai.bundle.config_parser import ConfigParser -from monai.data import DataLoader, Dataset, partition_dataset +from monai.data import DataLoader, Dataset from monai.data.utils import no_collation from monai.transforms import Compose, EnsureTyped, LoadImaged, Orientationd from monai.utils import StrEnum, min_version, optional_import from monai.utils.enums import DataStatsKeys, ImageStatsKeys -import warnings -# remove the warning "warnings.warn(f"Modifying image pixdim from {pixdim} to {norm}")" -warnings.filterwarnings("ignore", category=UserWarning, module='monai') def strenum_representer(dumper, data): return dumper.represent_scalar("tag:yaml.org,2002:str", data.value) @@ -65,7 +61,8 @@ class DataAnalyzer: average: whether to average the statistical value across different image modalities. do_ccp: apply the connected component algorithm to process the labels/images device: a string specifying hardware (CUDA/CPU) utilized for the operations. - worker: number of workers to use for loading datasets in each GPU/CPU sub-process. + worker: number of workers to use for parallel processing. If device is cuda/GPU, worker has + to be 0. image_key: a string that user specify for the image. The DataAnalyzer will look it up in the datalist to locate the image files of the dataset. label_key: a string that user specify for the label. The DataAnalyzer will look it up in the @@ -121,7 +118,7 @@ def __init__( output_path: str = "./datastats.yaml", average: bool = True, do_ccp: bool = False, - device: str | torch.device = "cuda", + device: str | torch.device = "cpu", worker: int = 4, image_key: str = "image", label_key: str | None = "label", @@ -173,9 +170,8 @@ def _check_data_uniformity(keys: list[str], result: dict) -> bool: def get_all_case_stats(self, key="training", transform_list=None): """ - Get all case stats. Caller of the DataAnalyser class. The function initiates multiple GPU or CPU processes of the internal - _get_all_case_stats functions, which iterates datalist and call SegSummarizer to generate stats for each case. - After all case stats are generated, SegSummarizer is called to combine results. + Get all case stats. Caller of the DataAnalyser class. The function iterates datalist and + call get_case_stats to generate stats. Then get_case_summary is called to combine results. Args: key: dataset key @@ -201,71 +197,6 @@ def get_all_case_stats(self, key="training", transform_list=None): dictionary will include .nan/.inf in the statistics. """ - if self.device.type == 'cpu': - nprocs = 1 - logger.info(f'Using CPU for data analyzing!') - else: - nprocs = torch.cuda.device_count() - logger.info(f'Found {nprocs} GPUs for data analyzing!') - set_start_method('forkserver', force=True) - with Manager() as manager: - manager_list = manager.list() - processes = [] - for rank in range(nprocs): - p = Process(target=self._get_all_case_stats, args=(rank, nprocs, manager_list, key, transform_list)) - processes.append(p) - for p in processes: - p.start() - for p in processes: - p.join() - # merge DataStatsKeys.BY_CASE - result: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} - result_bycase: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} - for _ in manager_list: - result_bycase[DataStatsKeys.BY_CASE].extend(_[DataStatsKeys.BY_CASE]) - summarizer = SegSummarizer( - self.image_key, - self.label_key, - average=self.average, - do_ccp=self.do_ccp, - hist_bins=self.hist_bins, - hist_range=self.hist_range, - histogram_only=self.histogram_only, - ) - n_cases = len(result_bycase[DataStatsKeys.BY_CASE]) - result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(list, result_bycase[DataStatsKeys.BY_CASE])) - result[DataStatsKeys.SUMMARY]["n_cases"] = n_cases - result_bycase[DataStatsKeys.SUMMARY] = result[DataStatsKeys.SUMMARY] - if not self._check_data_uniformity([ImageStatsKeys.SPACING], result): - logger.info("Data spacing is not completely uniform. MONAI transforms may provide unexpected result") - - if self.output_path: - ConfigParser.export_config_file( - result, self.output_path, fmt=self.fmt, default_flow_style=None, sort_keys=False - ) - ConfigParser.export_config_file( - result_bycase, - self.output_path.replace(".yaml", "_by_case.yaml"), - fmt=self.fmt, - default_flow_style=None, - sort_keys=False, - ) - # release memory - d = None - if self.device.type == "cuda": - # release unreferenced tensors to mitigate OOM - # limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237 - torch.cuda.empty_cache() - result[DataStatsKeys.BY_CASE] = result_bycase[DataStatsKeys.BY_CASE] - return result - - def _get_all_case_stats(self, rank: int=0, world_size: int=1, manager_list: Manager.list=[], key="training", transform_list=None): - """ - Get all case stats from a partitioned datalist. The function can only be called internally by get_all_case_stats. - Args: - key: dataset key - transform_list: option list of transforms before SegSummarizer - """ summarizer = SegSummarizer( self.image_key, self.label_key, @@ -278,7 +209,7 @@ def _get_all_case_stats(self, rank: int=0, world_size: int=1, manager_list: Mana keys = list(filter(None, [self.image_key, self.label_key])) if transform_list is None: transform_list = [ - LoadImaged(keys=keys, ensure_channel_first=True, image_only=False), + LoadImaged(keys=keys, ensure_channel_first=True, image_only=True), EnsureTyped(keys=keys, data_type="tensor", dtype=torch.float), Orientationd(keys=keys, axcodes="RAS"), ] @@ -293,11 +224,8 @@ def _get_all_case_stats(self, rank: int=0, world_size: int=1, manager_list: Mana ) transform = Compose(transform_list) + files, _ = datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1, key=key) - if world_size <= len(files): - files = partition_dataset(data=files, num_partitions=world_size)[rank] - else: - files = partition_dataset(data=files, num_partitions=len(files))[rank] if rank < len(files) else [] dataset = Dataset(data=files, transform=transform) dataloader = DataLoader( dataset, @@ -307,29 +235,22 @@ def _get_all_case_stats(self, rank: int=0, world_size: int=1, manager_list: Mana collate_fn=no_collation, pin_memory=self.device.type == "cuda", ) + result: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} result_bycase: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []} - device = self.device if self.device.type == 'cpu' else torch.device(f'cuda',rank) + if not has_tqdm: warnings.warn("tqdm is not installed. not displaying the caching progress.") - for batch_data in tqdm(dataloader) if (has_tqdm and rank==0) else dataloader: + for batch_data in tqdm(dataloader) if has_tqdm else dataloader: batch_data = batch_data[0] - try: - batch_data[self.image_key] = batch_data[self.image_key].to(device) - if self.label_key is not None: - label = batch_data[self.label_key] - label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] - batch_data[self.label_key] = label.to(device) - - d = summarizer(batch_data) - except: - logger.info(f"Unable to process data {batch_data['image_meta_dict']['filename_or_obj']} on {device}.") - if self.device.type == 'cuda': - logger.info(f"Data analysis using CPU.") - batch_data[self.image_key] = batch_data[self.image_key].to('cpu') - if self.label_key is not None: - batch_data[self.label_key] = label.to('cpu') - d = summarizer(batch_data) + batch_data[self.image_key] = batch_data[self.image_key].to(self.device) + + if self.label_key is not None: + label = batch_data[self.label_key] + label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] + batch_data[self.label_key] = label.to(self.device) + + d = summarizer(batch_data) stats_by_cases = { DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH], @@ -347,4 +268,36 @@ def _get_all_case_stats(self, rank: int=0, world_size: int=1, manager_list: Mana } ) result_bycase[DataStatsKeys.BY_CASE].append(stats_by_cases) - manager_list.append(result_bycase) + + n_cases = len(result_bycase[DataStatsKeys.BY_CASE]) + + result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(list, result_bycase[DataStatsKeys.BY_CASE])) + result[DataStatsKeys.SUMMARY]["n_cases"] = n_cases + result[DataStatsKeys.BY_CASE] = [None] * n_cases + + if not self._check_data_uniformity([ImageStatsKeys.SPACING], result): + print("Data spacing is not completely uniform. MONAI transforms may provide unexpected result") + + if self.output_path: + # saving summary and by_case as 2 files, to minimize loading time when only the summary is necessary + ConfigParser.export_config_file( + result, self.output_path, fmt=self.fmt, default_flow_style=None, sort_keys=False + ) + ConfigParser.export_config_file( + result_bycase, + self.output_path.replace(".yaml", "_by_case.yaml"), + fmt=self.fmt, + default_flow_style=None, + sort_keys=False, + ) + + # release memory + d = None + if self.device.type == "cuda": + # release unreferenced tensors to mitigate OOM + # limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237 + torch.cuda.empty_cache() + + # return combined + result[DataStatsKeys.BY_CASE] = result_bycase[DataStatsKeys.BY_CASE] + return result From 2c4fc6e715710670a481c268fe03a5ad92df8e69 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Wed, 22 Mar 2023 16:53:18 +0000 Subject: [PATCH 8/8] [MONAI] code formatting Signed-off-by: monai-bot --- monai/networks/nets/swin_unetr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index c1d4331a74..59fdb41815 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -65,7 +65,7 @@ def __init__( use_checkpoint: bool = False, spatial_dims: int = 3, downsample="merging", - use_v2=False + use_v2=False, ) -> None: """ Args: @@ -144,7 +144,7 @@ def __init__( use_checkpoint=use_checkpoint, spatial_dims=spatial_dims, downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample, - use_v2=use_v2 + use_v2=use_v2, ) self.encoder1 = UnetrBasicBlock( @@ -924,7 +924,7 @@ def __init__( use_checkpoint: bool = False, spatial_dims: int = 3, downsample="merging", - use_v2=False + use_v2=False, ) -> None: """ Args: @@ -1005,7 +1005,7 @@ def __init__( out_channels=embed_dim * 2**i_layer, kernel_size=3, stride=1, - norm_name='instance', + norm_name="instance", res_block=True, ) if i_layer == 0: