Skip to content
30 changes: 21 additions & 9 deletions monai/data/decathlon_datalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,43 @@


@overload
def _compute_path(base_dir: str, element: str) -> str:
def _compute_path(base_dir: str, element: str, check_path: bool = False) -> str:
...


@overload
def _compute_path(base_dir: str, element: List[str]) -> List[str]:
def _compute_path(base_dir: str, element: List[str], check_path: bool = False) -> List[str]:
...


def _compute_path(base_dir, element):
def _compute_path(base_dir, element, check_path=False):
"""
Args:
base_dir: the base directory of the dataset.
element: file path(s) to append to directory.
check_path: if `True`, only compute when the result is an existing path.

Raises:
TypeError: When ``element`` contains a non ``str``.
TypeError: When ``element`` type is not in ``Union[list, str]``.

"""

def _join_path(base_dir: str, item: str):
result = os.path.normpath(os.path.join(base_dir, item))
if check_path and not os.path.exists(result):
# if not an existing path, don't join with base dir
return item
return result

if isinstance(element, str):
return os.path.normpath(os.path.join(base_dir, element))
return _join_path(base_dir, element)
if isinstance(element, list):
for e in element:
if not isinstance(e, str):
raise TypeError(f"Every file path in element must be a str but got {type(element).__name__}.")
return [os.path.normpath(os.path.join(base_dir, e)) for e in element]
raise TypeError(f"element must be one of (str, list) but is {type(element).__name__}.")
return element
return [_join_path(base_dir, e) for e in element]
return element


def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> List[Dict]:
Expand All @@ -63,9 +72,12 @@ def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> Li
raise TypeError(f"Every item in items must be a dict but got {type(item).__name__}.")
for k, v in item.items():
if k == "image":
item[k] = _compute_path(base_dir, v)
item[k] = _compute_path(base_dir, v, check_path=False)
elif is_segmentation and k == "label":
item[k] = _compute_path(base_dir, v)
item[k] = _compute_path(base_dir, v, check_path=False)
else:
# for other items, auto detect whether it's a valid path
item[k] = _compute_path(base_dir, v, check_path=True)
return items


Expand Down
4 changes: 0 additions & 4 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,10 +656,6 @@ def __init__(self, keys: KeysCollection, name: str, dim: int = 0, allow_missing_
name: the name corresponding to the key to store the concatenated data.
dim: on which dimension to concatenate the items, default is 0.
allow_missing_keys: don't raise exception if key is missing.

Raises:
ValueError: When insufficient keys are given (``len(self.keys) < 2``).

"""
super().__init__(keys, allow_missing_keys)
self.name = name
Expand Down
25 changes: 25 additions & 0 deletions tests/test_load_decathlon_datalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,31 @@ def test_seg_no_labels(self):
result = load_decathlon_datalist(file_path, True, "test", tempdir)
self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_15.nii.gz"))

def test_additional_items(self):
with tempfile.TemporaryDirectory() as tempdir:
with open(os.path.join(tempdir, "mask31.txt"), "w") as f:
f.write("spleen31 mask")

test_data = {
"name": "Spleen",
"description": "Spleen Segmentation",
"labels": {"0": "background", "1": "spleen"},
"training": [
{"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz", "mask": "spleen mask"},
{"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz", "mask": "mask31.txt"},
],
"test": ["spleen_15.nii.gz", "spleen_23.nii.gz"],
}
json_str = json.dumps(test_data)
file_path = os.path.join(tempdir, "test_data.json")
with open(file_path, "w") as json_file:
json_file.write(json_str)
result = load_decathlon_datalist(file_path, True, "training", tempdir)
self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz"))
self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz"))
self.assertEqual(result[1]["mask"], os.path.join(tempdir, "mask31.txt"))
self.assertEqual(result[0]["mask"], "spleen mask")


if __name__ == "__main__":
unittest.main()