diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 6167e83e47..11fb5edd28 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -17,34 +17,43 @@ @overload -def _compute_path(base_dir: str, element: str) -> str: +def _compute_path(base_dir: str, element: str, check_path: bool = False) -> str: ... @overload -def _compute_path(base_dir: str, element: List[str]) -> List[str]: +def _compute_path(base_dir: str, element: List[str], check_path: bool = False) -> List[str]: ... -def _compute_path(base_dir, element): +def _compute_path(base_dir, element, check_path=False): """ Args: base_dir: the base directory of the dataset. element: file path(s) to append to directory. + check_path: if `True`, only compute when the result is an existing path. Raises: TypeError: When ``element`` contains a non ``str``. TypeError: When ``element`` type is not in ``Union[list, str]``. """ + + def _join_path(base_dir: str, item: str): + result = os.path.normpath(os.path.join(base_dir, item)) + if check_path and not os.path.exists(result): + # if not an existing path, don't join with base dir + return item + return result + if isinstance(element, str): - return os.path.normpath(os.path.join(base_dir, element)) + return _join_path(base_dir, element) if isinstance(element, list): for e in element: if not isinstance(e, str): - raise TypeError(f"Every file path in element must be a str but got {type(element).__name__}.") - return [os.path.normpath(os.path.join(base_dir, e)) for e in element] - raise TypeError(f"element must be one of (str, list) but is {type(element).__name__}.") + return element + return [_join_path(base_dir, e) for e in element] + return element def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> List[Dict]: @@ -63,9 +72,12 @@ def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> Li raise TypeError(f"Every item in items must be a dict but got {type(item).__name__}.") for k, v in item.items(): if k == "image": - item[k] = _compute_path(base_dir, v) + item[k] = _compute_path(base_dir, v, check_path=False) elif is_segmentation and k == "label": - item[k] = _compute_path(base_dir, v) + item[k] = _compute_path(base_dir, v, check_path=False) + else: + # for other items, auto detect whether it's a valid path + item[k] = _compute_path(base_dir, v, check_path=True) return items diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index f57cbd1116..c437cd055b 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -656,10 +656,6 @@ def __init__(self, keys: KeysCollection, name: str, dim: int = 0, allow_missing_ name: the name corresponding to the key to store the concatenated data. dim: on which dimension to concatenate the items, default is 0. allow_missing_keys: don't raise exception if key is missing. - - Raises: - ValueError: When insufficient keys are given (``len(self.keys) < 2``). - """ super().__init__(keys, allow_missing_keys) self.name = name diff --git a/tests/test_load_decathlon_datalist.py b/tests/test_load_decathlon_datalist.py index 90b9d3ab03..fe7ff6f8a2 100644 --- a/tests/test_load_decathlon_datalist.py +++ b/tests/test_load_decathlon_datalist.py @@ -96,6 +96,31 @@ def test_seg_no_labels(self): result = load_decathlon_datalist(file_path, True, "test", tempdir) self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_15.nii.gz")) + def test_additional_items(self): + with tempfile.TemporaryDirectory() as tempdir: + with open(os.path.join(tempdir, "mask31.txt"), "w") as f: + f.write("spleen31 mask") + + test_data = { + "name": "Spleen", + "description": "Spleen Segmentation", + "labels": {"0": "background", "1": "spleen"}, + "training": [ + {"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz", "mask": "spleen mask"}, + {"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz", "mask": "mask31.txt"}, + ], + "test": ["spleen_15.nii.gz", "spleen_23.nii.gz"], + } + json_str = json.dumps(test_data) + file_path = os.path.join(tempdir, "test_data.json") + with open(file_path, "w") as json_file: + json_file.write(json_str) + result = load_decathlon_datalist(file_path, True, "training", tempdir) + self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) + self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz")) + self.assertEqual(result[1]["mask"], os.path.join(tempdir, "mask31.txt")) + self.assertEqual(result[0]["mask"], "spleen mask") + if __name__ == "__main__": unittest.main()