diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 18540abab0..8fe534ff11 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1061,12 +1061,11 @@ def _transform(self, index: int): if not self.transform: return data - transformed_data = apply_transform(self.transform, data) + result = apply_transform(self.transform, data) - if not isinstance(transformed_data, dict): - raise AssertionError("With a dict supplied to apply_transform a single dict return is expected.") - - return transformed_data + if isinstance(result, dict) or (isinstance(result, list) and isinstance(result[0], dict)): + return result + raise AssertionError("With a dict supplied to apply_transform, should return a dict or a list of dicts.") class CSVDataset(Dataset):