diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 2a04c9bb50..2c485f03eb 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -161,6 +161,8 @@ def _check_data_uniformity(keys: list[str], result: dict) -> bool: """ + if DataStatsKeys.SUMMARY not in result or DataStatsKeys.IMAGE_STATS not in result[DataStatsKeys.SUMMARY]: + return True constant_props = [result[DataStatsKeys.SUMMARY][DataStatsKeys.IMAGE_STATS][key] for key in keys] for prop in constant_props: if "stdev" in prop and np.any(prop["stdev"]): @@ -358,10 +360,11 @@ def _get_all_case_stats( stats_by_cases = { DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH], DataStatsKeys.BY_CASE_LABEL_PATH: d[DataStatsKeys.BY_CASE_LABEL_PATH], - DataStatsKeys.IMAGE_STATS: d[DataStatsKeys.IMAGE_STATS], } + if not self.histogram_only: + stats_by_cases[DataStatsKeys.IMAGE_STATS] = d[DataStatsKeys.IMAGE_STATS] if self.hist_bins != 0: - stats_by_cases.update({DataStatsKeys.IMAGE_HISTOGRAM: d[DataStatsKeys.IMAGE_HISTOGRAM]}) + stats_by_cases[DataStatsKeys.IMAGE_HISTOGRAM] = d[DataStatsKeys.IMAGE_HISTOGRAM] if self.label_key is not None: stats_by_cases.update( diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index d0da03d4c6..654999d439 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -198,7 +198,7 @@ class ImageStats(Analyzer): """ - def __init__(self, image_key: str, stats_name: str = "image_stats") -> None: + def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS) -> None: if not isinstance(image_key, str): raise ValueError("image_key input must be str") @@ -296,7 +296,7 @@ class FgImageStats(Analyzer): """ - def __init__(self, image_key: str, label_key: str, stats_name: str = "image_foreground_stats"): + def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKeys.FG_IMAGE_STATS): self.image_key = image_key self.label_key = label_key @@ -378,7 +378,9 @@ class LabelStats(Analyzer): """ - def __init__(self, image_key: str, label_key: str, stats_name: str = "label_stats", do_ccp: bool | None = True): + def __init__( + self, image_key: str, label_key: str, stats_name: str = DataStatsKeys.LABEL_STATS, do_ccp: bool | None = True + ): self.image_key = image_key self.label_key = label_key self.do_ccp = do_ccp @@ -533,7 +535,7 @@ class ImageStatsSumm(Analyzer): """ - def __init__(self, stats_name: str = "image_stats", average: bool | None = True): + def __init__(self, stats_name: str = DataStatsKeys.IMAGE_STATS, average: bool | None = True): self.summary_average = average report_format = { ImageStatsKeys.SHAPE: None, @@ -623,7 +625,7 @@ class FgImageStatsSumm(Analyzer): """ - def __init__(self, stats_name: str = "image_foreground_stats", average: bool | None = True): + def __init__(self, stats_name: str = DataStatsKeys.FG_IMAGE_STATS, average: bool | None = True): self.summary_average = average report_format = {ImageStatsKeys.INTENSITY: None} @@ -687,7 +689,9 @@ class LabelStatsSumm(Analyzer): """ - def __init__(self, stats_name: str = "label_stats", average: bool | None = True, do_ccp: bool | None = True): + def __init__( + self, stats_name: str = DataStatsKeys.LABEL_STATS, average: bool | None = True, do_ccp: bool | None = True + ): self.summary_average = average self.do_ccp = do_ccp diff --git a/monai/auto3dseg/seg_summarizer.py b/monai/auto3dseg/seg_summarizer.py index d38ad582ac..14a10635df 100644 --- a/monai/auto3dseg/seg_summarizer.py +++ b/monai/auto3dseg/seg_summarizer.py @@ -100,9 +100,9 @@ def __init__( self.summary_analyzers: list[Any] = [] super().__init__() + self.add_analyzer(FilenameStats(image_key, DataStatsKeys.BY_CASE_IMAGE_PATH), None) + self.add_analyzer(FilenameStats(label_key, DataStatsKeys.BY_CASE_LABEL_PATH), None) if not self.histogram_only: - self.add_analyzer(FilenameStats(image_key, DataStatsKeys.BY_CASE_IMAGE_PATH), None) - self.add_analyzer(FilenameStats(label_key, DataStatsKeys.BY_CASE_LABEL_PATH), None) self.add_analyzer(ImageStats(image_key), ImageStatsSumm(average=average)) if label_key is None: diff --git a/tests/test_auto3dseg.py b/tests/test_auto3dseg.py index 272fb52f1a..5964ddd6e9 100644 --- a/tests/test_auto3dseg.py +++ b/tests/test_auto3dseg.py @@ -190,6 +190,21 @@ def test_data_analyzer_cpu(self, input_params): assert len(datastat["stats_by_cases"]) == len(sim_datalist["training"]) + def test_data_analyzer_histogram(self): + create_sim_data( + self.dataroot_dir, sim_datalist, [32] * 3, image_only=True, rad_max=8, rad_min=1, num_seg_classes=1 + ) + analyser = DataAnalyzer( + self.datalist_file, + self.dataroot_dir, + output_path=self.datastat_file, + label_key=None, + device=device, + histogram_only=True, + ) + datastat = analyser.get_all_case_stats() + assert len(datastat["stats_by_cases"]) == len(sim_datalist["training"]) + @parameterized.expand(SIM_GPU_TEST_CASES) @skip_if_no_cuda def test_data_analyzer_gpu(self, input_params):