diff --git a/.codecov.yml b/.codecov.yml index e5b9d64f..c6370468 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -11,4 +11,5 @@ coverage: ignore: - "tests/" - "*/__init.py" + - "_version.py" diff --git a/requirements.txt b/requirements.txt index 3b7480fb..5b2778d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,6 @@ numpy torch +threadpoolctl +pandas +sklearn +tqdm \ No newline at end of file diff --git a/rising/loading/__init__.py b/rising/loading/__init__.py new file mode 100644 index 00000000..135e4153 --- /dev/null +++ b/rising/loading/__init__.py @@ -0,0 +1,4 @@ +from rising.loading.collate import numpy_collate +from rising.loading.dataset import Dataset +from rising.loading.loader import DataLoader +from rising.loading.debug_mode import get_debug_mode, set_debug_mode, switch_debug_mode diff --git a/rising/loading/collate.py b/rising/loading/collate.py new file mode 100644 index 00000000..152933d4 --- /dev/null +++ b/rising/loading/collate.py @@ -0,0 +1,47 @@ +import numpy as np +import torch +import collections.abc +from typing import Any + + +default_collate_err_msg_format = ( + "default_collate: batch must contain tensors, numpy arrays, numbers, " + "dicts or lists; found {}") + + +def numpy_collate(batch: Any) -> Any: + """ + function to collate the samples to a whole batch of numpy arrays. + PyTorch Tensors, scalar values and sequences will be casted to arrays + automatically. + + Parameters + ---------- + batch : Any + a batch of samples. 
In most cases this is either a sequence, + a mapping or a mixture of them + + Returns + ------- + Any + collated batch with optionally converted type (to numpy array) + + """ + elem = batch[0] + if isinstance(elem, np.ndarray): + return np.stack(batch, 0) + elif isinstance(elem, torch.Tensor): + return numpy_collate([b.detach().cpu().numpy() for b in batch]) + elif isinstance(elem, float) or isinstance(elem, int): + return np.array(batch) + elif isinstance(elem, str): + return batch + elif isinstance(elem, collections.abc.Mapping): + return {key: numpy_collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + return type(elem)(*(numpy_collate(samples) for samples in zip(*batch))) + elif isinstance(elem, collections.abc.Sequence): + transposed = zip(*batch) + return [numpy_collate(samples) for samples in transposed] + + raise TypeError(default_collate_err_msg_format.format(type(elem))) \ No newline at end of file diff --git a/rising/loading/container.py b/rising/loading/container.py new file mode 100644 index 00000000..9677f156 --- /dev/null +++ b/rising/loading/container.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import pandas as pd +import typing +import pathlib +from collections import defaultdict + +from rising.loading.dataset import Dataset +from rising.loading.splitter import SplitType + + +class DataContainer: + def __init__(self, dataset: Dataset): + """ + Handles the splitting of datasets from different sources + + Parameters + ---------- + dataset : dataset + the dataset to split + kwargs + """ + self._dataset = dataset + self._dset = {} + self._fold = None + super().__init__() + + def split_by_index(self, split: SplitType) -> None: + """ + Splits dataset by a given split-dict + + Parameters + ---------- + split : dict + a dictionary containing tuples of strings and lists of indices + for each split + + """ + for key, idx in split.items(): + self._dset[key] = 
self._dataset.get_subset(idx) + + def kfold_by_index(self, splits: typing.Iterable[SplitType]): + """ + Produces kfold splits based on the given indices. + + Parameters + ---------- + splits : list + list containing split dicts for each fold + + Yields + ------ + DataContainer + the data container with updated dataset splits + + """ + for fold, split in enumerate(splits): + self.split_by_index(split) + self._fold = fold + yield self + self._fold = None + + def split_by_csv(self, path: typing.Union[pathlib.Path, str], + index_column: str, **kwargs) -> None: + """ + Splits a dataset by splits given in a CSV file + + Parameters + ---------- + path : str, pathlib.Path + the path to the csv file + index_column : str + the label of the index column + **kwargs : + additional keyword arguments (see :func:`pandas.read_csv` for + details) + + """ + df = pd.read_csv(path, **kwargs) + df = df.set_index(index_column) + col = list(df.columns) + self.split_by_index(self._read_split_from_df(df, col[0])) + + def kfold_by_csv(self, path: typing.Union[pathlib.Path, str], + index_column: str, **kwargs) -> DataContainer: + """ + Produces kfold splits based on the given csv file. 
+ + Parameters + ---------- + path : str, pathlib.Path + the path to the csv file + index_column : str + the label of the index column + **kwargs : + additional keyword arguments (see :func:`pandas.read_csv` for + details) + + Yields + ------ + DataContainer + the data container with updated dataset splits + + """ + df = pd.read_csv(path, **kwargs) + df = df.set_index(index_column) + folds = list(df.columns) + splits = [self._read_split_from_df(df, fold) for fold in folds] + yield from self.kfold_by_index((splits)) + + @staticmethod + def _read_split_from_df(df: pd.DataFrame, col: str) -> SplitType: + """ + Helper function to read a split from a given data frame + + Parameters + ---------- + df : pandas.DataFrame + the dataframe containing the split + col : str + the column inside the data frame containing the split + + Returns + ------- + dict + a dictionary of lists. Contains a string-list-tuple per split + + """ + split = defaultdict(list) + for index, row in df[[col]].iterrows(): + split[str(row[col])].append(index) + return split + + @property + def dset(self) -> Dataset: + if not self._dset: + raise AttributeError("No Split found.") + else: + return self._dset + + @property + def fold(self) -> int: + if self._fold is None: + raise AttributeError( + "Fold not specified. 
Call `kfold_by_index` first.") + else: + return self._fold + + +class DataContainerID(DataContainer): + """ + Data Container Class for datasets with an ID + """ + + def split_by_id(self, split: SplitType) -> None: + """ + Splits the internal dataset by the given splits + + Parameters + ---------- + split : dict + dictionary containing a string-list tuple per split + + """ + split_idx = defaultdict(list) + for key, _id in split.items(): + for _i in _id: + split_idx[key].append(self._dataset.get_index_by_id(_i)) + return super().split_by_index(split_idx) + + def kfold_by_id( + self, + splits: typing.Iterable[SplitType]): + """ + Produces kfold splits by an ID + + Parameters + ---------- + splits : list + list of dicts each containing the splits for a separate fold + + Yields + ------ + DataContaimnerID + the data container with updated internal datasets + + """ + for fold, split in enumerate(splits): + self.split_by_id(split) + self._fold = fold + yield self + self._fold = None + + def split_by_csv_id(self, path: typing.Union[pathlib.Path, str], + id_column: str, **kwargs) -> None: + """ + Splits the internal dataset by a given id column in a given csv file + + Parameters + ---------- + path : str or pathlib.Path + the path to the csv file + id_column : str + the key of the id_column + **kwargs : + additionalm keyword arguments (see :func:`pandas.read_csv` for + details) + + """ + df = pd.read_csv(path, **kwargs) + df = df.set_index(id_column) + col = list(df.columns) + return self.split_by_id(self._read_split_from_df(df, col[0])) + + def kfold_by_csv_id(self, path: typing.Union[pathlib.Path, str], + id_column: str, **kwargs): + """ + Produces kfold splits by an ID column of a given csv file + + Parameters + ---------- + path : str or pathlib.Path + the path to the csv file + id_column : str + the key of the id_column + **kwargs : + additionalm keyword arguments (see :func:`pandas.read_csv` for + details) + + Yields + ------ + DataContaimnerID + the data container 
with updated internal datasets + + """ + df = pd.read_csv(path, **kwargs) + df = df.set_index(id_column) + folds = list(df.columns) + splits = [self._read_split_from_df(df, fold) for fold in folds] + yield from self.kfold_by_id((splits)) + + def save_split_to_csv_id(self, + path: typing.Union[pathlib.Path, str], + id_key: str, + split_column: str = 'split', + **kwargs) -> None: + """ + Saves a split top a given csv id + + Parameters + ---------- + path : str or pathlib.Path + the path of the csv file + id_key : str + the id key inside the csv file + split_column : str + the name of the split_column inside the csv file + **kwargs : + additional keyword arguments (see :meth:`pd.DataFrame.to_csv` + for details) + + """ + split_dict = {str(id_key): [], str(split_column): []} + for key, item in self._dset.items(): + for sample in item: + split_dict[str(id_key)].append(sample[id_key]) + split_dict[str(split_column)].append(str(key)) + pd.DataFrame(split_dict).to_csv(path, **kwargs) diff --git a/rising/loading/dataset.py b/rising/loading/dataset.py new file mode 100644 index 00000000..f881901f --- /dev/null +++ b/rising/loading/dataset.py @@ -0,0 +1,467 @@ +from __future__ import annotations + +import os +import typing +import pathlib +from functools import partial +from tqdm import tqdm +import warnings + +from torch.utils.data import Dataset as TorchDset +from rising.loading.debug_mode import get_debug_mode +from rising import AbstractMixin +from torch.multiprocessing import Pool + + +class Dataset(TorchDset): + """ + Extension of PyTorch's Datasets by a ``get_subset`` method which returns a + sub-dataset. 
+ """ + + def get_subset(self, indices: typing.Sequence[int]) -> SubsetDataset: + """ + Returns a Subset of the current dataset based on given indices + + Parameters + ---------- + indices : iterable + valid indices to extract subset from current dataset + + Returns + ------- + :class:`SubsetDataset` + the subset + """ + # extract other important attributes from current dataset + kwargs = {} + + for key, val in vars(self).items(): + if not (key.startswith("__") and key.endswith("__")): + + if key == "data": + continue + kwargs[key] = val + + old_getitem = self.__class__.__getitem__ + subset_data = [self[idx] for idx in indices] + + return SubsetDataset(subset_data, old_getitem, **kwargs) + + +class SubsetDataset(Dataset): + """ + A Dataset loading the data, which has been passed + in it's ``__init__`` by it's ``_sample_fn`` + """ + + def __init__(self, data: typing.Sequence, old_getitem: typing.Callable, + **kwargs): + """ + Parameters + ---------- + data : sequence + data to load (subset of original data) + old_getitem : function + get item method of previous dataset + **kwargs : + additional keyword arguments (are set as class attribute) + """ + super().__init__() + + self.data = data + self._old_getitem = old_getitem + + for key, val in kwargs.items(): + setattr(self, key, val) + + def __getitem__(self, index: int) -> typing.Union[typing.Dict, typing.Any]: + """ + returns single sample corresponding to ``index`` via the old get_item + Parameters + ---------- + index : int + index specifying the data to load + + Returns + ------- + Any, dict + can be any object containing a single sample, + but is often a dict-like. 
+ """ + return self._old_getitem(self, index) + + def __len__(self) -> int: + """ + returns the length of the dataset + + Returns + ------- + int + number of samples + """ + return len(self.data) + + +class CacheDataset(Dataset): + def __init__(self, + data_path: typing.Union[typing.Union[pathlib.Path, str], list], + load_fn: typing.Callable, + mode: str = "append", + num_workers: int = None, + verbose=False, + **load_kwargs): + """ + A dataset to preload all the data and cache it for the entire + lifetime of this class. + + Parameters + ---------- + data_path : str, Path or list + the path(s) containing the actual data samples + load_fn : function + function to load the actual data + mode : str + whether to append the sample to a list or to extend the list by + it. Supported modes are: :param:`append` and :param:`extend`. + Default: ``append`` + num_workers : int, optional + the number of workers to use for preloading. ``0`` means, all the + data will be loaded in the main process, while ``None`` means, + the number of processes will default to the number of logical + cores. + verbose : bool + whether to show the loading progress. Mutually exclusive with + ``num_workers is not None and num_workers > 0`` + **load_kwargs : + additional keyword arguments. Passed directly to :param:`load_fn` + """ + super().__init__() + + if get_debug_mode() and (num_workers is None or num_workers > 0): + warnings.warn("The debug mode has been activated. " + "Falling back to num_workers = 0", UserWarning) + num_workers = 0 + + if (num_workers is None or num_workers > 0) and verbose: + warnings.warn("Verbosity is mutually exclusive with " + "num_workers > 0. 
Setting it to False instead.", UserWarning) + verbose = False + + self._num_workers = num_workers + self._verbosity = verbose + + self._load_fn = load_fn + self._load_kwargs = load_kwargs + self.data = self._make_dataset(data_path, mode) + + def _make_dataset(self, path: typing.Union[typing.Union[pathlib.Path, str], list], + mode: str) -> typing.List[dict]: + """ + Function to build the entire dataset + + Parameters + ---------- + path : str, Path or list + the path(s) containing the data samples + mode : str + whether to append or extend the dataset by the loaded sample + + Returns + ------- + list + the loaded data + + """ + data = [] + if not isinstance(path, list): + assert os.path.isdir(path), '%s is not a valid directory' % path + path = [os.path.join(path, p) for p in os.listdir(path)] + + # sort for reproducibility (this is done explicitly since the listdir + # function does not return the paths in an ordered way on all OS) + path = sorted(path) + + # add loading kwargs + load_fn = partial(self._load_fn, **self._load_kwargs) + + # multiprocessing dispatch + if self._num_workers is None or self._num_workers > 0: + with Pool() as p: + _data = p.map(load_fn, path) + else: + if self._verbosity: + path = tqdm(path, unit='samples', desc="Loading Samples") + _data = map(load_fn, path) + + for sample in _data: + self._add_item(data, sample, mode) + return data + + @staticmethod + def _add_item(data: list, item: typing.Any, mode: str) -> None: + """ + Adds items to the given data list. The actual way of adding these + items depends on :param:`mode` + + Parameters + ---------- + data : list + the list containing the already loaded data + item : Any + the current item which will be added to the list + mode : str + the string specifying the mode of how the item should be added. 
+ + """ + _mode = mode.lower() + + if _mode == 'append': + data.append(item) + elif _mode == 'extend': + data.extend(item) + else: + raise TypeError(f"Unknown mode detected: {mode} not supported.") + + def __getitem__(self, index: int) -> typing.Union[typing.Any, typing.Dict]: + """ + Making the whole Dataset indexeable. + + Parameters + ---------- + index : int + the integer specifying which sample to return + + Returns + ------- + Any, Dict + can be any object containing a single sample, but in practice is + often a dict + + """ + return self.data[index] + + def __len__(self) -> int: + """ + Length of dataset + + Returns + ------- + int + number of elements + """ + return len(self.data) + + +class LazyDataset(Dataset): + def __init__(self, data_path: typing.Union[str, list], + load_fn: typing.Callable, + **load_kwargs): + """ + A dataset to load all the data just in time. + + Parameters + ---------- + data_path : str, Path or list + the path(s) containing the actual data samples + load_fn : function + function to load the actual data + load_kwargs: + additional keyword arguments (passed to :param:`load_fn`) + """ + super().__init__() + self._load_fn = load_fn + self._load_kwargs = load_kwargs + self.data = self._make_dataset(data_path) + + def _make_dataset(self, path: typing.Union[typing.Union[pathlib.Path, str], + list]) -> typing.List[dict]: + """ + Function to build the entire dataset + + Parameters + ---------- + path : str, Path or list + the path(s) containing the data samples + + Returns + ------- + list + the loaded data + + """ + if not isinstance(path, list): + assert os.path.isdir(path), '%s is not a valid directory' % path + path = [os.path.join(path, p) for p in os.listdir(path)] + + sorted(path) + return path + + def __getitem__(self, index: int) -> dict: + """ + Making the whole Dataset indexeable. Loads the necessary sample. 
+ + Parameters + ---------- + index : int + the integer specifying which sample to load and return + + Returns + ------- + Any, Dict + can be any object containing a single sample, but in practice is + often a dict + + """ + data_dict = self._load_fn(self.data[index], + **self._load_kwargs) + return data_dict + + def __len__(self) -> int: + """ + Length of dataset + + Returns + ------- + int + number of elements + """ + return len(self.data) + + +class IDManager(AbstractMixin): + def __init__(self, id_key: str, cache_ids: bool = True, **kwargs): + """ + Helper class to add additional functionality to Datasets + + Parameters + ---------- + id_key : str + the id key to cache + cache_ids : bool + whether to cache the ids + **kwargs : + additional keyword arguments + """ + super().__init__(**kwargs) + self.id_key = id_key + self._cached_ids = None + + if cache_ids: + self.cache_ids() + + def cache_ids(self) -> None: + """ + Caches the IDs + + """ + self._cached_ids = { + sample[self.id_key]: idx for idx, sample in enumerate(self)} + + def _find_index_iterative(self, id: str) -> int: + """ + Checks for the next index matching the given id + + Parameters + ---------- + id : str + the id to get the index for + + Returns + ------- + int + the returned index + + Raises + ------ + KeyError + no index matching the given id + + """ + for idx, sample in enumerate(self): + if sample[self.id_key] == id: + return idx + raise KeyError(f"ID {id} not found.") + + def get_sample_by_id(self, id: str) -> dict: + """ + Fetches the sample to a corresponding ID + + Parameters + ---------- + id : str + the id specifying the sample to return + + Returns + ------- + dict + the sample corresponding to the given ID + + """ + return self[self.get_index_by_id(id)] + + def get_index_by_id(self, id: str) -> int: + """ + Returns the index corresponding to a given id + + Parameters + ---------- + id : str + the id specifying the index of which sample should be returned + + Returns + ------- + int + 
the index of the sample matching the given id + + """ + if self._cached_ids is not None: + return self._cached_ids[id] + else: + return self._find_index_iterative(id) + + +class CacheDatasetID(IDManager, CacheDataset): + def __init__(self, data_path, load_fn, id_key, cache_ids=True, + **kwargs): + """ + Caching version of ID Dataset + + Parameters + ---------- + data_path : str, Path or list + the path(s) containing the actual data samples + load_fn : function + function to load the actual data + id_key : str + the id key to cache + cache_ids : bool + whether to cache the ids + **kwargs : + additional keyword arguments + """ + super().__init__(data_path=data_path, load_fn=load_fn, id_key=id_key, + cache_ids=cache_ids, **kwargs) + + +class LazyDatasetID(IDManager, LazyDataset): + def __init__(self, data_path, load_fn, id_key, cache_ids=True, + **kwargs): + """ + Lazy version of ID Dataset + + Parameters + ---------- + data_path : str, Path or list + the path(s) containing the actual data samples + load_fn : function + function to load the actual data + id_key : str + the id key to cache + cache_ids : bool + whether to cache the ids + **kwargs : + additional keyword arguments + """ + super().__init__(data_path=data_path, load_fn=load_fn, id_key=id_key, + cache_ids=cache_ids, **kwargs) diff --git a/rising/loading/debug_mode.py b/rising/loading/debug_mode.py new file mode 100644 index 00000000..a75485e1 --- /dev/null +++ b/rising/loading/debug_mode.py @@ -0,0 +1,40 @@ +__DEBUG_MODE = False + +# Functions to get and set the internal __DEBUG_MODE variable. This variable +# currently only defines whether to use multiprocessing or not. At the moment +# this is only used inside the DataManager, which either returns a +# MultiThreadedAugmenter or a SingleThreadedAugmenter depending on the current +# debug mode. 
+# All other functions using multiprocessing should be aware of this and +# implement a functionality without multiprocessing +# (even if this slows down things a lot!). + + +def get_debug_mode(): + """ + Getter function for the current debug mode + Returns + ------- + bool + current debug mode + """ + return __DEBUG_MODE + + +def switch_debug_mode(): + """ + Alternates the current debug mode + """ + set_debug_mode(not get_debug_mode()) + + +def set_debug_mode(mode: bool): + """ + Sets a new debug mode + Parameters + ---------- + mode : bool + the new debug mode + """ + global __DEBUG_MODE + __DEBUG_MODE = mode \ No newline at end of file diff --git a/rising/loading/loader.py b/rising/loading/loader.py new file mode 100644 index 00000000..b67a6507 --- /dev/null +++ b/rising/loading/loader.py @@ -0,0 +1,243 @@ +from __future__ import annotations +from typing import Callable, Mapping, Sequence, Union, Any +from torch.utils.data._utils.collate import default_convert +from torch.utils.data import DataLoader as _DataLoader, Sampler +from torch.utils.data.dataloader import \ + _SingleProcessDataLoaderIter, _MultiProcessingDataLoaderIter as \ + __MultiProcessingDataLoaderIter +from rising.loading.debug_mode import get_debug_mode +from functools import partial +from rising.loading.dataset import Dataset +from threadpoolctl import threadpool_limits +import numpy as np + + +class DataLoader(_DataLoader): + def __init__(self, dataset: Union[Sequence, Dataset], + batch_size: int = 1, shuffle: bool = False, + batch_transforms: Callable = None, sampler: Sampler = None, + batch_sampler: Sampler = None, num_workers: int = 0, + collate_fn: Callable = None, + pin_memory: bool = False, drop_last: bool = False, + timeout: Union[int, float] = 0, + worker_init_fn: Callable = None, + multiprocessing_context=None, + auto_convert: bool = True): + """ + A Dataloader introducing batch-transforms, numpy seeds for worker + processes and compatibility to the debug mode + + Note + ---- + For 
Reproducibility numpy and pytorch must be seeded in the main + process, as these frameworks will be used to generate their own seeds + for each worker. + + Note + ---- + ``len(dataloader)`` heuristic is based on the length of the sampler + used. When :attr:`dataset` is an + :class:`~torch.utils.data.IterableDataset`, an infinite sampler is + used, whose :meth:`__len__` is not implemented, because the actual + length depends on both the iterable as well as multi-process loading + configurations. So one should not query this method unless they work + with a map-style dataset. See `Dataset Types`_ for more details on + these two types of datasets. + + Warning + ------- + If the ``spawn`` start method is used, :attr:`worker_init_fn` + cannot be an unpicklable object, e.g., a lambda function. See + :ref:`multiprocessing-best-practices` on more details related + to multiprocessing in PyTorch. + + Parameters + ---------- + dataset : Dataset + dataset from which to load the data + batch_size : int, optional + how many samples per batch to load (default: ``1``). + shuffle : bool, optional + set to ``True`` to have the data reshuffled at every epoch + (default: ``False``) + batch_transforms : callable, optional + transforms which can be applied to a whole batch. + Usually this accepts either mappings or sequences and returns the + same type containing transformed elements + sampler : torch.utils.data.Sampler, optional + defines the strategy to draw samples from + the dataset. If specified, :attr:`shuffle` must be ``False``. + batch_sampler : torch.utils.data.Sampler, optional + like :attr:`sampler`, but returns a batch of + indices at a time. Mutually exclusive with :attr:`batch_size`, + :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`. + num_workers : int, optional + how many subprocesses to use for data loading. + ``0`` means that the data will be loaded in the main process. 
+ (default: ``0``) + collate_fn : callable, optional + merges a list of samples to form a + mini-batch of Tensor(s). Used when using batched loading from a + map-style dataset. + pin_memory : bool, optional + If ``True``, the data loader will copy Tensors + into CUDA pinned memory before returning them. If your data + elements are a custom type, or your :attr:`collate_fn` returns a + batch that is a custom type, see the example below. + drop_last : bool, optional + set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. + If ``False`` and the size of dataset is not divisible by the batch + size, then the last batch will be smaller. (default: ``False``) + timeout : numeric, optional + if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: ``0``) + worker_init_fn : callable, optional + If not ``None``, this will be called on each + worker subprocess with the worker id + (an int in ``[0, num_workers - 1]``) as input, after seeding and + before data loading. (default: ``None``) + auto_convert : bool, optional + if set to ``True``, the batches will always be transformed to + torch.Tensors, if possible. 
(default: ``True``) + """ + + super().__init__(dataset=dataset, batch_size=batch_size, + shuffle=shuffle, sampler=sampler, + batch_sampler=batch_sampler, num_workers=num_workers, + collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, + worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context) + + self.collate_fn = BatchTransformer(self.collate_fn, batch_transforms, + auto_convert) + + def __iter__(self) -> Union[_SingleProcessDataLoaderIter, + _MultiProcessingDataLoaderIter]: + if self.num_workers == 0 or get_debug_mode(): + return _SingleProcessDataLoaderIter(self) + else: + return _MultiProcessingDataLoaderIter(self) + + +class BatchTransformer(object): + """ + A callable wrapping the collate_fn to enable transformations on a + batch-basis. + """ + + def __init__(self, collate_fn: Callable, transforms: Callable = None, + auto_convert: bool = True): + """ + Parameters + ---------- + collate_fn : callable, optional + merges a list of samples to form a + mini-batch of Tensor(s). Used when using batched loading from a + map-style dataset. + transforms : callable, optional + transforms which can be applied to a whole batch. + Usually this accepts either mappings or sequences and returns the + same type containing transformed elements + auto_convert : bool, optional + if set to ``True``, the batches will always be transformed to + torch.Tensors, if possible. 
(default: ``True``) + """ + + self._collate_fn = collate_fn + self._transforms = transforms + self._auto_convert = auto_convert + + def __call__(self, *args, **kwargs) -> Any: + batch = self._collate_fn(*args, **kwargs) + + if self._transforms is not None: + if isinstance(batch, Mapping): + batch = self._transforms(**batch) + elif isinstance(batch, Sequence): + batch = self._transforms(*batch) + else: + batch = self._transforms(batch) + + if self._auto_convert: + batch = default_convert(batch) + + return batch + + +class _MultiProcessingDataLoaderIter(__MultiProcessingDataLoaderIter): + # NOTE [ Numpy Seeds ] + # This class is a subclass of + # ``torch.utils.data.dataloader._MultiProcessingDataLoaderIter``` and only + # adds some additional logic to provide differnt seeds for numpy in + # each worker. These seeds are based on a base seed, which itself get's + # generated by numpy. So to ensure reproducibility, numpy must be seeded + # in the main process. + def __init__(self, loader): + """ + Iterator over Dataloader. Handles the complete multiprocessing + + Parameters + ---------- + loader : DataLoader + the dataloader instance to iterate over + """ + + try: + import numpy as np + # generate numpy seed. The range comes so that the seed in each + # worker (which is this baseseed plus the worker id) is always an + # uint32. 
This is because numpy only accepts uint32 as valid seeds + npy_seed = np.random.randint(0, (2 ** 32) - (1 + loader.num_workers)) + except ImportError: + # we don't generate a numpy seed here with torch, since we don't + # need one; if the import fails in the main process it should + # also fail in child processes + npy_seed = None + + old_worker_init = loader.worker_init_fn + + if npy_seed is None: + new_worker_init_fn = old_worker_init + else: + new_worker_init_fn = partial(_seed_npy_before_worker_init, + seed=npy_seed, + worker_init_fn=old_worker_init) + loader.worker_init_fn = new_worker_init_fn + + with threadpool_limits(limits=1, user_api='blas'): + super().__init__(loader) + + # reset worker_init_fn once the workers have been startet to reset + # to original state for next epoch + loader.worker_init_fn = old_worker_init + + +def _seed_npy_before_worker_init(worker_id: int, seed: int, + worker_init_fn: Callable = None): + """ + Wrapper Function to wrap the existing worker_init_fn and seed numpy before + calling the actual ``worker_init_fn`` + + Parameters + ---------- + worker_id : int + the number of the worker + seed : int32 + the base seed in a range of [0, 2**32 - (1 + ``num_workers``)]. 
+ The range ensures, that the whole seed, which consists of the base + seed and the ``worker_id``, can still be represented as a unit32, + as it needs to be for numpy seeding + worker_init_fn : callable, optional + will be called with the ``worker_id`` after seeding numpy if it is not + ``None`` + """ + try: + import numpy as np + np.random.seed(seed + worker_id) + except ImportError: + pass + + if worker_init_fn is not None: + return worker_init_fn(worker_id) diff --git a/rising/loading/splitter.py b/rising/loading/splitter.py new file mode 100644 index 00000000..019407d0 --- /dev/null +++ b/rising/loading/splitter.py @@ -0,0 +1,387 @@ +import copy +import typing +import logging +import warnings +from sklearn.model_selection import train_test_split, GroupShuffleSplit, \ + KFold, GroupKFold, StratifiedKFold + +from rising.loading.dataset import Dataset + +logger = logging.getLogger(__file__) + +SplitType = typing.Dict[str, list] + + +class Splitter: + def __init__(self, + dataset: Dataset, + val_size: typing.Union[int, float] = 0, + test_size: typing.Union[int, float] = 0): + """ + Splits a dataset by several options + + Parameters + ---------- + dataset : Dataset + the dataset to split + val_size : float, int + the validation split; + if float this will be interpreted as a percentage of the + dataset + if int this will be interpreted as the number of samples + test_size : float, int , optionally + the size of the validation split; If provided it must be int or + float. 
+ if float this will be interpreted as a percentage of the + dataset + if int this will be interpreted as the number of samples + if not provided or explicitly set to None, no testset will be + created + """ + super().__init__() + if val_size == 0 and test_size == 0: + warnings.warn("Can not perform splitting if val and test size is 0.") + + self._dataset = dataset + self._total_num = len(self._dataset) + self._idx = list(range(self._total_num)) + self._val = val_size + self._test = test_size + + self._convert_prop_to_num() + self._check_sizes() + + def _check_sizes(self): + """ + Checks if the given sizes are valid for splitting + + Raises + ------ + ValueError + at least one of the sizes is invalid + + """ + if self._total_num <= 0: + raise ValueError("Size must be larger than zero, not " + "{}".format(self._total_num)) + if self._val <= 0: + raise ValueError("Size must be larger than zero, not " + "{}".format(self._val)) + if self._test < 0: + raise ValueError("Size must be larger than zero, not " + "{}".format(self._test)) + + if self._total_num < self._val + self._test: + raise ValueError("Val + test size must be smaller than total, " + "not {}".format(self._val + self._test)) + + def index_split(self, **kwargs) -> SplitType: + """ + Splits the dataset's indices in a random way + + Parameters + ---------- + **kwargs : + optional keyword arguments. 
+ See :func:`sklearn.model_selection.train_test_split` for details + + Returns + ------- + dict + the dictionary containing the corresponding splits under the + keys 'train', 'val' and (optionally) 'test' + + """ + split_dict = {} + split_dict["train"], tmp = train_test_split( + self._idx, test_size=self._val + self._test, + **kwargs) + + if self._test > 0: + # update stratified if provided, + # necessary for index_split_stratified + if 'stratify' in kwargs: + kwargs['stratify'] = [kwargs['stratify'][_i] for _i in tmp] + split_dict["val"], split_dict["test"] = train_test_split( + tmp, test_size=self._val, **kwargs) + else: + split_dict["val"] = tmp + self.log_split(split_dict, "Created Single Split with:") + return split_dict + + def index_split_stratified(self, stratify_key: str = "label", **kwargs) -> SplitType: + """ + Splits the dataset's indices in a stratified way + + Parameters + ---------- + stratify_key : str + the key specifying which value of each sample to use for + stratification + **kwargs : + optional keyword arguments. + See :func:`sklearn.model_selection.train_test_split` for details + + Returns + ------- + dict + the dictionary containing the corresponding splits under the + keys 'train', 'val' and (optionally) 'test' + + """ + stratify = [d[stratify_key] for d in self._dataset] + return self.index_split(stratify=stratify, **kwargs) + + def index_split_grouped(self, groups_key: str = "id", **kwargs) -> SplitType: + """ + Splits the dataset's indices in a stratified way + + Parameters + ---------- + groups_key : str + the key specifying which value of each sample to use for + grouping + **kwargs : + optional keyword arguments. 
+ See :func:`sklearn.model_selection.train_test_split` for details + + Returns + ------- + dict + the dictionary containing the corresponding splits under the + keys 'train', 'val' and (optionally) 'test' + + Warnings + -------- + Shuffling cannot be deactivated + """ + split_dict = {} + groups = [d[groups_key] for d in self._dataset] + + gsp = GroupShuffleSplit( + n_splits=1, test_size=self._val + self._test, **kwargs) + split_dict["train"], tmp = next(gsp.split(self._idx, groups=groups)) + + if self._test > 0: + groups_tmp = [groups[_i] for _i in tmp] + gsp = GroupShuffleSplit(n_splits=1, test_size=self._val, **kwargs) + split_dict["val"], split_dict["test"] = next( + gsp.split(tmp, groups=groups_tmp)) + else: + split_dict["val"] = tmp + self.log_split(split_dict, "Created Single Split with:") + return split_dict + + def index_kfold_fixed_test(self, **kwargs) -> typing.Iterable[SplitType]: + """ + Calculates splits for a random kfold with given testset. + If :param:`test_size` is zero, a normal kfold is generated + + Parameters + ---------- + **kwargs : + optional keyword arguments. 
+ See :func:`sklearn.model_selection.train_test_split` for details + + Returns + ------- + list + list containing one dict for each fold each containing the + corresponding splits under the keys 'train', 'val' and 'test' + + """ + splits = [] + + idx_dict = self.index_split(**kwargs) + train_val_idx = idx_dict.pop("train") + idx_dict.pop("val") + + logger.info("Creating {} folds.".format(self.val_folds)) + kf = KFold(n_splits=self.val_folds, **kwargs) + _fold = 0 + for train_idx, val_idx in kf.split(train_val_idx): + splits.append(self._copy_and_fill_dict( + idx_dict, train=train_idx, val=val_idx)) + self.log_split(splits[-1], f"Created Fold{_fold}.") + _fold += 1 + return splits + + def index_kfold_fixed_test_stratified( + self, + stratify_key: str = "label", + **kwargs) -> typing.Iterable[SplitType]: + """ + Calculates splits for a stratified kfold with given testset + If :param:`test_size` is zero, a normal kfold is generated + + Parameters + ---------- + stratify_key : str + the key specifying which value of each sample to use for + stratification + **kwargs : + optional keyword arguments. 
+ See :func:`sklearn.model_selection.train_test_split` for details + + Returns + ------- + list + list containing one dict for each fold each containing the + corresponding splits under the keys 'train', 'val' and 'test' + + """ + splits = [] + + idx_dict = self.index_split_stratified(**kwargs) + train_val_idx = idx_dict.pop("train") + idx_dict.pop("val") + train_val_stratify = [ + self._dataset[_i][stratify_key] for _i in train_val_idx] + + logger.info("Creating {} folds.".format(self.val_folds)) + kf = StratifiedKFold(n_splits=self.val_folds, **kwargs) + _fold = 0 + for train_idx, val_idx in kf.split(train_val_idx, train_val_stratify): + splits.append(self._copy_and_fill_dict( + idx_dict, train=train_idx, val=val_idx)) + self.log_split(splits[-1], f"Created Fold{_fold}.") + _fold += 1 + return splits + + def index_kfold_fixed_test_grouped(self, groups_key: str = "id", + **kwargs) -> typing.Iterable[SplitType]: + """ + Calculates splits for a stratified kfold with given testset + If :param:`test_size` is zero, a normal kfold is generated + + Parameters + ---------- + groups_key : str + the key specifying which value of each sample to use for + grouping + **kwargs : + optional keyword arguments. 
+ See :func:`sklearn.model_selection.train_test_split` for details + + Returns + ------- + list + list containing one dict for each fold each containing the + corresponding splits under the keys 'train', 'val' and 'test' + + """ + splits = [] + + idx_dict = self.index_split_grouped(**kwargs) + train_val_idx = idx_dict.pop("train") + idx_dict.pop("val") + train_val_groups = [ + self._dataset[_i][groups_key] for _i in train_val_idx] + + logger.info("Creating {} folds.".format(self.val_folds)) + kf = GroupKFold(n_splits=self.val_folds, **kwargs) + _fold = 0 + for train_idx, val_idx in kf.split( + train_val_idx, groups=train_val_groups): + splits.append(self._copy_and_fill_dict( + idx_dict, train=train_idx, val=val_idx)) + self.log_split(splits[-1], f"Created Fold{_fold}.") + _fold += 1 + return splits + + def _convert_prop_to_num(self, attributes: tuple = ("_val", "_test") + ) -> None: + """ + Converts all given attributes from percentages to number of samples + if necessary + + Parameters + ---------- + attributes : tuple + tuple of strings containing the attribute names + + """ + for attr in attributes: + value = getattr(self, attr) + if 0 < value < 1: + setattr(self, attr, int(value * self._total_num)) + + @staticmethod + def log_split(dict_like: dict, desc: str = None) -> None: + """ + Logs the new created split + + Parameters + ---------- + dict_like : dict + the splits (usually this dict contains the keys 'train', 'val' + and (optionally) 'test' and a list of indices for each of them + desc : str, optional + the descriptor string to log before the actual splits + + """ + if desc is not None: + logger.info(desc) + for key, item in dict_like.items(): + logger.info(f"{str(key).upper()} contains {len(item)} indices.") + + @staticmethod + def _copy_and_fill_dict(dict_like: dict, **kwargs) -> dict: + """ + copies the dict and adds the kwargs to the copy + + Parameters + ---------- + dict_like : dict + the dict to copy and fill + **kwargs : + the keyword argument added to 
the dict copy + + Returns + ------- + dict + the copied and filled dict + + """ + new_dict = copy.deepcopy(dict_like) + new_dict.update(kwargs) + return new_dict + + @property + def dataset(self) -> Dataset: + return self._dataset + + @dataset.setter + def dataset(self, dset: Dataset): + self._dataset = dset + self._total_num = len(self._dataset) + self._idx = list(range(self._total_num)) + + @property + def val_size(self) -> int: + return self._val + + @val_size.setter + def val_size(self, value: typing.Union[int, float]): + self._val = value + self._convert_prop_to_num() + self._check_sizes() + + @property + def test_size(self) -> int: + return self._test + + @test_size.setter + def test_size(self, value: typing.Union[int, float]): + self._test = value + self._convert_prop_to_num() + self._check_sizes() + + @property + def folds(self) -> int: + return self.val_folds * self.test_folds + + @property + def val_folds(self) -> int: + return int(self._total_num // self._val) + + @property + def test_folds(self) -> int: + return int(self._total_num // self._test) \ No newline at end of file diff --git a/rising/utils/checktype.py b/rising/utils/checktype.py index 9eac016e..b0d09447 100644 --- a/rising/utils/checktype.py +++ b/rising/utils/checktype.py @@ -1,5 +1,18 @@ def check_scalar(x): + """ + Provide interface to check for scalars + + Parameters + ---------- + x: typing.Any + object to check for scalar + + Returns + ------- + bool + True if input is scalar + """ if isinstance(x, (int, float)): return True else: diff --git a/setup.py b/setup.py index 9a4636e8..5f613f7d 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,6 @@ from setuptools import setup, find_packages import versioneer import os -import re def resolve_requirements(file): @@ -27,7 +26,6 @@ def read_file(file): 'requirements.txt')) readme = read_file(os.path.join(os.path.dirname(__file__), "README.md")) -license = read_file(os.path.join(os.path.dirname(__file__), "LICENSE")) setup( @@ -35,14 +33,15 @@ def 
read_file(file): version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), packages=find_packages(), - # url='path/to/url', + url='https://github.com/phoenixdl/rising', test_suite="unittest", long_description=readme, long_description_content_type='text/markdown', install_requires=requirements, tests_require=["coverage"], python_requires=">=3.7", - author="Michael Baumgartner", - author_email="michael.baumgartner@rwth-aachen.de", - license=license, + author="PhoenixDL", + maintainer='Michael Baumgartner, Justus Schock', + maintainer_email='{michael.baumgartner, justus.schock}@rwth-aachen.de', + license='MIT', ) diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..ac3ff41f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +from ._utils import LoadDummySample, DummyDataset, DummyDatasetID \ No newline at end of file diff --git a/tests/_utils.py b/tests/_utils.py new file mode 100644 index 00000000..ffa342a0 --- /dev/null +++ b/tests/_utils.py @@ -0,0 +1,30 @@ +import numpy as np + +from rising.loading.dataset import CacheDataset, CacheDatasetID + + +class LoadDummySample: + def __init__(self, keys=('data', 'label'), sizes=((3, 128, 128), (3,)), + **kwargs): + super().__init__(**kwargs) + self.keys = keys + self.sizes = sizes + + def __call__(self, path, *args, **kwargs): + data = {_k: np.random.rand(*_s) + for _k, _s in zip(self.keys, self.sizes)} + data['id'] = f'sample{path}' + return data + + +class DummyDataset(CacheDataset): + def __init__(self, num_samples=10, load_fn=LoadDummySample(), + **load_kwargs): + super().__init__(list(range(num_samples)), load_fn, **load_kwargs) + + +class DummyDatasetID(CacheDatasetID): + def __init__(self, num_samples=10, load_fn=LoadDummySample(), + **load_kwargs): + super().__init__(list(range(num_samples)), load_fn, id_key="id", + **load_kwargs) diff --git a/tests/test_ops/__init__.py b/tests/loading/__init__.py similarity index 100% rename from tests/test_ops/__init__.py rename to 
tests/loading/__init__.py diff --git a/tests/loading/_src/kfold.csv b/tests/loading/_src/kfold.csv new file mode 100644 index 00000000..8a6f1ae7 --- /dev/null +++ b/tests/loading/_src/kfold.csv @@ -0,0 +1,7 @@ +index;fold0;fold1;fold2 +0;train;test;val +1;train;test;val +2;val;train;test +3;val;train;test +4;test;val;train +5;test;val;train \ No newline at end of file diff --git a/tests/loading/_src/split.csv b/tests/loading/_src/split.csv new file mode 100644 index 00000000..b1f886f6 --- /dev/null +++ b/tests/loading/_src/split.csv @@ -0,0 +1,7 @@ +index;fold0 +0;train +1;train +2;val +3;val +4;test +5;test \ No newline at end of file diff --git a/tests/loading/test_collate.py b/tests/loading/test_collate.py new file mode 100644 index 00000000..f2b641b0 --- /dev/null +++ b/tests/loading/test_collate.py @@ -0,0 +1,91 @@ +import torch +import unittest +from collections import namedtuple + +try: + import numpy as np +except ImportError: + np = None + +from rising.loading.collate import numpy_collate + + +class TestCollate(unittest.TestCase): + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_int(self): + arr = [1, 2, -1] + collated = numpy_collate(arr) + expected = np.array(arr) + self.assertTrue((collated == expected).all()) + self.assertEqual(collated.dtype, expected.dtype) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_float(self): + arr = [1.1, 2.3, -0.9] + collated = numpy_collate(arr) + expected = np.array(arr) + self.assertTrue((collated == expected).all()) + self.assertEqual(collated.dtype, expected.dtype) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_bool(self): + arr = [True, False] + collated = numpy_collate(arr) + self.assertTrue(all(collated == np.array(arr))) + self.assertEqual(collated.dtype, np.bool) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_str(self): + # Should be a no-op + arr = ['a', 'b', 'c'] + 
self.assertTrue((arr == numpy_collate(arr))) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_ndarray(self): + arr = [np.array(0), np.array(1), np.array(2)] + collated = numpy_collate(arr) + expected = np.array([0, 1, 2]) + self.assertTrue((collated == expected).all()) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_tensor(self): + arr = [torch.tensor(0), torch.tensor(1), torch.tensor(2)] + collated = numpy_collate(arr) + expected = np.array([0, 1, 2]) + self.assertTrue((collated == expected).all()) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_mapping(self): + arr = [{"a": np.array(0), "b": np.array(1)}] * 2 + collated = numpy_collate(arr) + expected = {"a": np.array([0, 0]), "b": np.array([1, 1])} + for key in expected.keys(): + self.assertTrue((collated[key] == expected[key]).all()) + self.assertEqual(len(expected.keys()), len(collated.keys())) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_sequence(self): + arr = [[np.array(0), np.array(1)], [np.array(0), np.array(1)]] + collated = numpy_collate(arr) + expected = [np.array([0, 0]), np.array([1, 1])] + for i in range(len(collated)): + self.assertTrue((collated[i] == expected[i]).all()) + self.assertEqual(len(expected), len(collated)) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_error(self): + with self.assertRaises(TypeError): + collated = numpy_collate([{"a", "b"}, {"a", "b"}]) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_named_tuple(self): + Point = namedtuple('Point', ['x', 'y']) + arr = [Point(0, 1), Point(2, 3)] + collated = numpy_collate(arr) + expected = Point(np.array([0, 2]), np.array([1, 3])) + self.assertTrue((collated.x == expected.x).all()) + self.assertTrue((collated.y == expected.y).all()) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git 
a/tests/loading/test_container.py b/tests/loading/test_container.py new file mode 100644 index 00000000..449a4bcc --- /dev/null +++ b/tests/loading/test_container.py @@ -0,0 +1,133 @@ +import unittest +import os +from pathlib import Path +import numpy as np +import pandas as pd +from rising.loading.container import DataContainer, DataContainerID +from tests import DummyDataset, DummyDatasetID + + +class LoadDummySampleID: + def __init__(self, keys=('data', 'label'), sizes=((3, 128, 128), (3,)), + **kwargs): + super().__init__(**kwargs) + self.keys = keys + self.sizes = sizes + + def __call__(self, path, *args, **kwargs): + data = {_k: np.random.rand(*_s) + for _k, _s in zip(self.keys, self.sizes)} + data['id'] = int(path) + return data + + +class TestDataContainer(unittest.TestCase): + def setUp(self): + self.dset = DummyDataset(num_samples=6, + load_fn=(LoadDummySampleID()), + ) + self.dset_id = DummyDatasetID(num_samples=6, + load_fn=LoadDummySampleID(), + ) + self.split = {"train": [0, 1, 2], "val": [3, 4, 5]} + self.kfold = [{"train": [0, 1, 2], "val": [3, 4, 5]}, + {"val": [0, 1, 2], "train": [3, 4, 5]}] + + def test_empty_container(self): + container = DataContainer(self.dset) + with self.assertRaises(AttributeError): + container.dset["train"] + + with self.assertRaises(AttributeError): + container.fold + + def test_split_by_index(self): + container = DataContainer(self.dset) + container.split_by_index(self.split) + self.check_split(container) + + def check_split(self, container): + self._assert_split(container, [0, 1, 2], [3, 4, 5]) + + def test_kfold_by_index(self): + container = DataContainer(self.dset) + self.check_kfold(container.kfold_by_index(self.kfold)) + + def check_kfold(self, container_generator): + for container_fold in container_generator: + if container_fold.fold == 0: + self._assert_split(container_fold, [0, 1, 2], [3, 4, 5]) + elif container_fold.fold == 1: + self._assert_split(container_fold, [3, 4, 5], [0, 1, 2]) + else: + 
self.assertTrue(False, "Unknown Fold") + + def test_split_by_csv(self): + container = DataContainer(self.dset) + p = os.path.join(os.path.dirname(__file__), '_src', 'split.csv') + container.split_by_csv(p, 'index', sep=';') + self.check_split_csv(container) + + def check_split_csv(self, container): + self._assert_split(container, [0, 1], [2, 3], [4, 5]) + + def test_kfold_by_csv(self): + container = DataContainer(self.dset) + p = os.path.join(os.path.dirname(__file__), '_src', 'kfold.csv') + self.check_kfold_csv(container.kfold_by_csv(p, 'index', sep=';')) + + def check_kfold_csv(self, container_generator): + for container_fold in container_generator: + if container_fold.fold == 0: + self._assert_split(container_fold, [0, 1], [2, 3], [4, 5]) + elif container_fold.fold == 1: + self._assert_split(container_fold, [2, 3], [4, 5], [0, 1]) + elif container_fold.fold == 2: + self._assert_split(container_fold, [4, 5], [0, 1], [2, 3]) + else: + self.assertTrue(False, "Unknown Fold") + + def _assert_split(self, container, train, val=None, test=None): + self.assertEqual([d["id"] for d in container.dset["train"]], train) + if val is not None: + self.assertEqual([d["id"] for d in container.dset["val"]], val) + if test is not None: + self.assertEqual([d["id"] for d in container.dset["test"]], test) + + def test_split_by_index_id(self): + container = DataContainerID(self.dset_id) + container.split_by_id(self.split) + self.check_split(container) + + def test_kfold_by_index_id(self): + container = DataContainerID(self.dset_id) + self.check_kfold(container.kfold_by_id(self.kfold)) + + def test_split_by_csv_id(self): + container = DataContainerID(self.dset_id) + p = os.path.join(os.path.dirname(__file__), '_src', 'split.csv') + container.split_by_csv_id(p, 'index', sep=';') + self.check_split_csv(container) + + def test_kfold_by_csv_id(self): + container = DataContainerID(self.dset_id) + p = os.path.join(os.path.dirname(__file__), '_src', 'kfold.csv') + 
self.check_kfold_csv(container.kfold_by_csv_id(p, 'index', sep=';')) + + def test_save_split_to_csv_id(self): + container = DataContainerID(self.dset_id) + container.split_by_id(self.split) + p = os.path.join(os.path.dirname(__file__), + "_src", "test_generated_split.csv") + + container.save_split_to_csv_id(p, "id") + + df = pd.read_csv(p) + self.assertTrue(df["id"].to_list(), [1, 2, 3, 4, 5, 6]) + self.assertTrue(df["split"].to_list(), + ["train", "train", "train", "val", "val", "val"]) + os.remove(p) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/loading/test_dataset.py b/tests/loading/test_dataset.py new file mode 100644 index 00000000..1afcd82a --- /dev/null +++ b/tests/loading/test_dataset.py @@ -0,0 +1,202 @@ +import unittest +import os +import tempfile +import shutil +import pickle + +import numpy as np +from rising.loading.dataset import CacheDataset, LazyDataset, CacheDatasetID, \ + LazyDatasetID +from rising.loading import get_debug_mode, set_debug_mode + + +class LoadDummySample: + def __call__(self, path, *args, **kwargs): + data = {'data': np.random.rand(1, 256, 256), + 'label': np.random.randint(2), + 'id': f"sample{path}"} + return data + + +def pickle_save(path, data): + with open(path, "wb") as f: + pickle.dump(data, f) + + +def pickle_load(path, *args, **kwargs): + with open(path, "rb") as f: + return pickle.load(f) + + +class TestBaseDatasetDir(unittest.TestCase): + def setUp(self) -> None: + self.dir = tempfile.mkdtemp(dir=os.path.dirname(os.path.realpath(__file__))) + loader = LoadDummySample() + for i in range(10): + pickle_save(os.path.join(self.dir, f"sample{i}.pkl"), loader(i)) + + def tearDown(self) -> None: + shutil.rmtree(self.dir) + + def test_cache_dataset_dir(self): + dataset = CacheDataset(self.dir, pickle_load, label_load_fct=None) + self.assertEqual(len(dataset), 10) + for i in dataset: + pass + + def test_lazy_dataset_dir(self): + dataset = LazyDataset(self.dir, pickle_load, 
label_load_fct=None) + self.assertEqual(len(dataset), 10) + for i in dataset: + pass + + +class TestBaseDataset(unittest.TestCase): + def setUp(self): + self.paths = list(range(10)) + + def test_cache_dataset(self): + dataset = CacheDataset(self.paths, LoadDummySample(), + label_load_fct=None) + self.assertEqual(len(dataset), 10) + self.check_dataset_access(dataset, [0, 5, 9]) + self.check_dataset_outside_access(dataset, [10, 20]) + self.check_dataset_iter(dataset) + + def test_cache_num_worker_warn(self): + set_debug_mode(True) + with self.assertWarns(UserWarning): + dataset = CacheDataset(self.paths, LoadDummySample(), + num_workers=4, + label_load_fct=None) + set_debug_mode(False) + + def test_cache_verbose_warn(self): + + with self.assertWarns(UserWarning): + dataset = CacheDataset(self.paths, LoadDummySample(), + num_workers=4, verbose=True, + label_load_fct=None) + + def test_cache_dataset_extend(self): + def load_mul_sample(path) -> list: + return [LoadDummySample()(path, None)] * 4 + + dataset = CacheDataset(self.paths, load_mul_sample, + num_workers=0, verbose=True, + mode='extend') + self.assertEqual(len(dataset), 40) + self.check_dataset_access(dataset, [0, 20, 39]) + self.check_dataset_outside_access(dataset, [40, 45]) + self.check_dataset_iter(dataset) + + def test_cache_dataset_mode_error(self): + with self.assertRaises(TypeError): + dataset = CacheDataset(self.paths, LoadDummySample(), + label_load_fct=None, mode="no_mode:P") + + def test_lazy_dataset(self): + dataset = LazyDataset(self.paths, LoadDummySample(), + label_load_fct=None) + self.assertEqual(len(dataset), 10) + self.check_dataset_access(dataset, [0, 5, 9]) + self.check_dataset_outside_access(dataset, [10, 20]) + self.check_dataset_iter(dataset) + + def check_dataset_access(self, dataset, inside_idx): + try: + for _i in inside_idx: + a = dataset[_i] + except BaseException: + self.assertTrue(False) + + def check_dataset_outside_access(self, dataset, outside_idx): + for _i in outside_idx: + 
with self.assertRaises(IndexError): + a = dataset[_i] + + def check_dataset_iter(self, dataset): + try: + j = 0 + for i in dataset: + self.assertIn('data', i) + self.assertIn('label', i) + j += 1 + assert j == len(dataset) + except BaseException: + raise AssertionError('Dataset iteration failed.') + + def test_subset_dataset(self): + idx = [0, 1, 2, 5, 6] + dataset = CacheDataset(self.paths, LoadDummySample(), + label_load_fct=None) + subset = dataset.get_subset(idx) + self.assertEqual(len(subset), len(idx)) + for _i, _idx in enumerate(idx): + self.assertEqual(subset[_i]["id"], dataset[_idx]["id"]) + with self.assertRaises(IndexError): + subset[len(idx)] + + +class TestDatasetID(unittest.TestCase): + def test_load_dummy_sample(self): + load_fn = LoadDummySample() + sample0 = load_fn(None, None) + self.assertIn("data", sample0) + self.assertIn("label", sample0) + self.assertTrue("id", "sample0") + + sample1 = load_fn(None, None) + self.assertIn("data", sample1) + self.assertIn("label", sample1) + self.assertTrue("id", "sample1") + + def check_dataset( + self, + dset_cls, + num_samples, + expected_len, + debug_num, + **kwargs): + load_fn = LoadDummySample() + dset = dset_cls(list(range(num_samples)), load_fn, + debug_num=debug_num, **kwargs) + self.assertEqual(len(dset), expected_len) + + def test_base_cache_dataset(self): + self.check_dataset(CacheDataset, num_samples=20, + expected_len=20, debug_num=10) + + def test_base_lazy_dataset_debug_off(self): + self.check_dataset(LazyDataset, num_samples=20, + expected_len=20, debug_num=10) + + def test_cachedataset_id(self): + load_fn = LoadDummySample() + dset = CacheDatasetID(list(range(10)), load_fn, + id_key="id", cache_ids=False) + self.check_dset_id(dset) + + def test_lazydataset_id(self): + load_fn = LoadDummySample() + dset = LazyDatasetID(list(range(10)), load_fn, + id_key="id", cache_ids=False) + self.check_dset_id(dset) + + def check_dset_id(self, dset): + idx = dset.get_index_by_id("sample1") + 
self.assertTrue(idx, 1) + + with self.assertRaises(KeyError): + idx = dset.get_index_by_id("sample10") + + sample5 = dset.get_sample_by_id("sample5") + self.assertTrue(sample5["id"], 5) + + dset.cache_ids() + sample6 = dset.get_sample_by_id("sample6") + self.assertTrue(sample6["id"], 6) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/loading/test_debug_mode.py b/tests/loading/test_debug_mode.py new file mode 100644 index 00000000..e573d73d --- /dev/null +++ b/tests/loading/test_debug_mode.py @@ -0,0 +1,21 @@ +import unittest + +from rising.loading import get_debug_mode, set_debug_mode, switch_debug_mode + + +class TestDebugMode(unittest.TestCase): + def test_debug_mode(self): + set_debug_mode(False) + self.assertFalse(get_debug_mode()) + set_debug_mode(True) + self.assertTrue(get_debug_mode()) + set_debug_mode(False) + self.assertFalse(get_debug_mode()) + + def test_switch_debug_mode(self): + set_debug_mode(False) + self.assertFalse(get_debug_mode()) + switch_debug_mode() + self.assertTrue(get_debug_mode()) + switch_debug_mode() + self.assertFalse(get_debug_mode()) diff --git a/tests/loading/test_loader.py b/tests/loading/test_loader.py new file mode 100644 index 00000000..b8c7b8de --- /dev/null +++ b/tests/loading/test_loader.py @@ -0,0 +1,87 @@ +import unittest +import torch +from typing import Sequence, Mapping +import numpy as np +from unittest.mock import Mock, patch +from torch.utils.data.dataloader import _SingleProcessDataLoaderIter + +from rising.loading.loader import _seed_npy_before_worker_init, DataLoader, \ + BatchTransformer, _MultiProcessingDataLoaderIter + + +class TestLoader(unittest.TestCase): + def test_seed_npy_before_worker_init(self): + expected_return = 100 + np.random.seed(1) + expected = np.random.rand(1) + worker_init = Mock(return_value=expected_return) + + output_return = _seed_npy_before_worker_init(worker_id=1, seed=0, + worker_init_fn=worker_init) + output = np.random.rand(1) + self.assertEqual(output, expected) + 
self.assertEqual(output_return, expected_return) + worker_init.assert_called_once_with(1) + + def test_seed_npy_before_worker_init_import_error(self): + with patch.dict('sys.modules', {'numpy': None}): + expected_return = 100 + worker_init = Mock(return_value=expected_return) + output_return = _seed_npy_before_worker_init(worker_id=1, seed=0, + worker_init_fn=worker_init) + self.assertEqual(output_return, expected_return) + worker_init.assert_called_once_with(1) + + def check_batch_transformer(self, collate_output): + collate = Mock(return_value=collate_output) + transforms = Mock(return_value=2) + transformer = BatchTransformer(collate_fn=collate, transforms=transforms, + auto_convert=False) + + output = transformer(0) + + collate.assert_called_once_with(0) + if isinstance(collate_output, Sequence): + transforms.assert_called_once_with(*collate_output) + elif isinstance(collate_output, Mapping): + transforms.assert_called_once_with(**collate_output) + else: + transforms.assert_called_once_with(collate_output) + self.assertEqual(2, output) + + def test_batch_transformer(self): + self.check_batch_transformer(0) + + def test_batch_transformer_sequence(self): + self.check_batch_transformer((0, 1)) + + def test_batch_transformer_mapping(self): + self.check_batch_transformer({"a": 0}) + + def test_batch_transformer_auto_convert(self): + collate = Mock(return_value=0) + transforms = Mock(return_value=np.array([0, 1])) + transformer = BatchTransformer(collate_fn=collate, transforms=transforms, + auto_convert=True) + output = transformer(0) + self.assertTrue((output == torch.tensor([0, 1])).all()) + + def test_dataloader_np_import_error(self): + with patch.dict('sys.modules', {'numpy': None}): + loader = DataLoader([0, 1, 2], num_workers=2) + iterator = iter(loader) + self.assertIsInstance(iterator, _MultiProcessingDataLoaderIter) + + def test_dataloader_single_process(self): + loader = DataLoader([0, 1, 2]) + iterator = iter(loader) + self.assertIsInstance(iterator, 
_SingleProcessDataLoaderIter) + + def test_dataloader_multi_process(self): + loader = DataLoader([0, 1, 2], num_workers=2) + iterator = iter(loader) + self.assertIsInstance(iterator, _MultiProcessingDataLoaderIter) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_transforms/test_functional/__init__.py b/tests/ops/__init__.py similarity index 100% rename from tests/test_transforms/test_functional/__init__.py rename to tests/ops/__init__.py diff --git a/tests/test_ops/test_tensor.py b/tests/ops/test_tensor.py similarity index 100% rename from tests/test_ops/test_tensor.py rename to tests/ops/test_tensor.py diff --git a/tests/test_transforms/__init__.py b/tests/transforms/__init__.py similarity index 100% rename from tests/test_transforms/__init__.py rename to tests/transforms/__init__.py diff --git a/tests/test_transforms/_utils.py b/tests/transforms/_utils.py similarity index 100% rename from tests/test_transforms/_utils.py rename to tests/transforms/_utils.py diff --git a/tests/transforms/functional/__init__.py b/tests/transforms/functional/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_transforms/test_functional/test_crop.py b/tests/transforms/functional/test_crop.py similarity index 100% rename from tests/test_transforms/test_functional/test_crop.py rename to tests/transforms/functional/test_crop.py diff --git a/tests/test_transforms/test_functional/test_intensity.py b/tests/transforms/functional/test_intensity.py similarity index 100% rename from tests/test_transforms/test_functional/test_intensity.py rename to tests/transforms/functional/test_intensity.py diff --git a/tests/test_transforms/test_functional/test_spatial.py b/tests/transforms/functional/test_spatial.py similarity index 100% rename from tests/test_transforms/test_functional/test_spatial.py rename to tests/transforms/functional/test_spatial.py diff --git a/tests/test_transforms/test_abstract_transform.py 
b/tests/transforms/test_abstract_transform.py similarity index 100% rename from tests/test_transforms/test_abstract_transform.py rename to tests/transforms/test_abstract_transform.py diff --git a/tests/test_transforms/test_compose.py b/tests/transforms/test_compose.py similarity index 100% rename from tests/test_transforms/test_compose.py rename to tests/transforms/test_compose.py diff --git a/tests/test_transforms/test_crop.py b/tests/transforms/test_crop.py similarity index 100% rename from tests/test_transforms/test_crop.py rename to tests/transforms/test_crop.py diff --git a/tests/test_transforms/test_format_transforms.py b/tests/transforms/test_format_transforms.py similarity index 100% rename from tests/test_transforms/test_format_transforms.py rename to tests/transforms/test_format_transforms.py diff --git a/tests/test_transforms/test_intensity_transforms.py b/tests/transforms/test_intensity_transforms.py similarity index 99% rename from tests/test_transforms/test_intensity_transforms.py rename to tests/transforms/test_intensity_transforms.py index 66e5026d..7063d7cc 100644 --- a/tests/test_transforms/test_intensity_transforms.py +++ b/tests/transforms/test_intensity_transforms.py @@ -4,7 +4,7 @@ from math import isclose from unittest.mock import Mock, call -from tests.test_transforms import chech_data_preservation +from tests.transforms import chech_data_preservation from rising.transforms.intensity import * diff --git a/tests/test_transforms/test_kernel_transforms.py b/tests/transforms/test_kernel_transforms.py similarity index 97% rename from tests/test_transforms/test_kernel_transforms.py rename to tests/transforms/test_kernel_transforms.py index 28c8150a..61c7c74c 100644 --- a/tests/test_transforms/test_kernel_transforms.py +++ b/tests/transforms/test_kernel_transforms.py @@ -40,7 +40,6 @@ def test_gaussian_smoothing_transform(self): dim=2, stride=1, padding=1) self.batch_dict["data"][0, 0, 1] = 1 outp = trafo(**self.batch_dict) - 
print(outp["data"].shape) if __name__ == '__main__': diff --git a/tests/test_transforms/test_spatial_transforms.py b/tests/transforms/test_spatial_transforms.py similarity index 98% rename from tests/test_transforms/test_spatial_transforms.py rename to tests/transforms/test_spatial_transforms.py index 40d11ae7..8d6613fc 100644 --- a/tests/test_transforms/test_spatial_transforms.py +++ b/tests/transforms/test_spatial_transforms.py @@ -2,7 +2,7 @@ import random import unittest -from tests.test_transforms import chech_data_preservation +from tests.transforms import chech_data_preservation from rising.transforms.spatial import * from rising.transforms.functional.spatial import resize diff --git a/tests/test_transforms/test_utility_transforms.py b/tests/transforms/test_utility_transforms.py similarity index 100% rename from tests/test_transforms/test_utility_transforms.py rename to tests/transforms/test_utility_transforms.py