default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")


def numpy_collate(batch: Any) -> Any:
    """
    Collate a batch of samples into numpy arrays.

    PyTorch tensors, scalar values and (nested) sequences/mappings are
    converted recursively; strings are passed through unchanged.

    Parameters
    ----------
    batch : Any
        a batch of samples, typically a sequence, a mapping or a
        mixture of both

    Returns
    -------
    Any
        collated batch with leaves converted to numpy arrays where possible

    Raises
    ------
    TypeError
        if the element type is not supported
    """
    first = batch[0]

    if isinstance(first, np.ndarray):
        return np.stack(batch, 0)
    if isinstance(first, torch.Tensor):
        # detach + move to host before converting to numpy, then recurse
        return numpy_collate([sample.detach().cpu().numpy()
                              for sample in batch])
    if isinstance(first, (float, int)):
        return np.array(batch)
    if isinstance(first, str):
        return batch
    if isinstance(first, collections.abc.Mapping):
        return {key: numpy_collate([sample[key] for sample in batch])
                for key in first}
    if isinstance(first, tuple) and hasattr(first, '_fields'):  # namedtuple
        fields = zip(*batch)
        return type(first)(*(numpy_collate(field) for field in fields))
    if isinstance(first, collections.abc.Sequence):
        return [numpy_collate(column) for column in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(type(first)))
class Dataset(TorchDset):
    """
    Extension of PyTorch's Datasets by a ``get_subset`` method which returns a
    sub-dataset.
    """

    def get_subset(self, indices: typing.Sequence[int]) -> "SubsetDataset":
        """
        Returns a Subset of the current dataset based on given indices

        Parameters
        ----------
        indices : Sequence of int
            valid indices to extract subset from current dataset

        Returns
        -------
        :class:`SubsetDataset`
            the subset
        """
        # NOTE: the return annotation is a string on purpose:
        # SubsetDataset is defined below and the module's
        # ``from __future__ import annotations`` is misplaced (it is not the
        # first statement of the file), so eager evaluation would fail.

        # copy all non-dunder attributes except the data itself; they are
        # re-attached to the subset so it behaves like the original dataset
        kwargs = {}
        for key, val in vars(self).items():
            if key.startswith("__") and key.endswith("__"):
                continue
            if key == "data":
                continue
            kwargs[key] = val

        # the subset serves its samples through this class' __getitem__
        kwargs["old_getitem"] = self.__class__.__getitem__
        subset_data = [self[idx] for idx in indices]

        return SubsetDataset(subset_data, **kwargs)


# NOTE: For backward compatibility (should be removed ASAP)
AbstractDataset = Dataset


class SubsetDataset(Dataset):
    """
    A Dataset holding a subset of another dataset's (already loaded) samples
    and serving them through the parent dataset's ``__getitem__``.
    """

    def __init__(self, data: typing.Sequence, old_getitem: typing.Callable,
                 **kwargs):
        """
        Parameters
        ----------
        data : sequence
            data to load (subset of original data)
        old_getitem : function
            get item method of previous dataset
        **kwargs :
            additional keyword arguments (are set as class attribute)
        """
        # BUGFIX: the base class takes no constructor arguments;
        # ``super().__init__(None, None)`` raised a TypeError
        super().__init__()

        self.data = data
        self._old_getitem = old_getitem

        for key, val in kwargs.items():
            setattr(self, key, val)

    def __getitem__(self, index: int) -> typing.Union[typing.Dict, typing.Any]:
        """
        returns single sample corresponding to ``index`` via the old get_item

        Parameters
        ----------
        index : int
            index specifying the data to load

        Returns
        -------
        Any, dict
            can be any object containing a single sample,
            but is often a dict-like.
        """
        return self._old_getitem(self, index)

    def __len__(self) -> int:
        """
        returns the length of the dataset

        Returns
        -------
        int
            number of samples
        """
        return len(self.data)
class CacheDataset(Dataset):
    """
    A dataset that eagerly loads all samples via ``load_fn`` and keeps them
    cached in memory for its entire lifetime.
    """

    def __init__(self,
                 data_path: typing.Union[typing.Union[pathlib.Path, str],
                                         list],
                 load_fn: typing.Callable,
                 mode: str = "append",
                 num_workers: int = None,
                 verbose: bool = False,
                 **load_kwargs):
        """
        Parameters
        ----------
        data_path : str, Path or list
            the path(s) containing the actual data samples
        load_fn : function
            function to load the actual data
        mode : str
            whether to append the sample to a list or to extend the list by
            it. Supported modes are: :param:`append` and :param:`extend`.
            Default: ``append``
        num_workers : int, optional
            the number of workers to use for preloading. ``0`` means, all the
            data will be loaded in the main process, while ``None`` means,
            the number of processes will default to the number of logical
            cores.
        verbose : bool
            whether to show the loading progress. Mutually exclusive with
            ``num_workers is not None and num_workers > 0``
        **load_kwargs :
            additional keyword arguments. Passed directly to :param:`load_fn`
        """
        super().__init__()

        # in debug mode everything must run in the main process
        if (get_current_debug_mode() and
                (num_workers is None or num_workers > 0)):
            warnings.warn("The debug mode has been activated. "
                          "Falling back to num_workers = 0")
            num_workers = 0

        # a tqdm progress bar and a multiprocessing pool do not mix
        if (num_workers is None or num_workers > 0) and verbose:
            warnings.warn("Verbosity is mutually exclusive with "
                          "num_workers > 0. Setting it to False instead.",
                          UserWarning)
            verbose = False

        self._num_workers = num_workers
        self._verbosity = verbose

        self._load_fn = load_fn
        self._load_kwargs = load_kwargs
        self.data = self._make_dataset(data_path, mode)

    def _make_dataset(self, path: typing.Union[typing.Union[pathlib.Path,
                                                            str], list],
                      mode: str) -> typing.List[dict]:
        """
        Loads the entire dataset into a list.

        Parameters
        ----------
        path : str, Path or list
            the path(s) containing the data samples
        mode : str
            whether to append or extend the dataset by the loaded sample

        Returns
        -------
        list
            the loaded data
        """
        data = []
        if not isinstance(path, list):
            assert os.path.isdir(path), '%s is not a valid directory' % path
            path = [os.path.join(path, p) for p in os.listdir(path)]

        # sort for reproducibility (this is done explicitly since the listdir
        # function does not return the paths in an ordered way on all OS)
        path = sorted(path)

        # add loading kwargs
        load_fn = partial(self._load_fn, **self._load_kwargs)

        # multiprocessing dispatch
        if self._num_workers is None or self._num_workers > 0:
            # BUGFIX: the requested worker count was previously ignored
            # (``Pool()`` always used all logical cores);
            # ``processes=None`` still falls back to all logical cores
            with Pool(processes=self._num_workers) as p:
                _data = p.map(load_fn, path)
        else:
            if self._verbosity:
                path = tqdm(path, unit='samples', desc="Loading Samples")
            _data = map(load_fn, path)

        for sample in _data:
            self._add_item(data, sample, mode)
        return data

    @staticmethod
    def _add_item(data: list, item: typing.Any, mode: str) -> None:
        """
        Adds ``item`` to ``data`` according to ``mode``.

        Parameters
        ----------
        data : list
            the list containing the already loaded data
        item : Any
            the current item which will be added to the list
        mode : str
            ``'append'`` or ``'extend'`` (case-insensitive)

        Raises
        ------
        TypeError
            if an unsupported mode is given
        """
        _mode = mode.lower()

        if _mode == 'append':
            data.append(item)
        elif _mode == 'extend':
            data.extend(item)
        else:
            raise TypeError(f"Unknown mode detected: {mode} not supported.")

    def __getitem__(self, index: int) -> typing.Union[typing.Any, typing.Dict]:
        """
        Returns the cached sample at ``index``.

        Parameters
        ----------
        index : int
            the integer specifying which sample to return

        Returns
        -------
        Any, Dict
            can be any object containing a single sample, but in practice is
            often a dict
        """
        return self.data[index]
class LazyDataset(Dataset):
    """
    A dataset that only stores the sample paths and loads each sample
    just in time via ``load_fn``.
    """

    def __init__(self, data_path: typing.Union[str, list],
                 load_fn: typing.Callable,
                 **load_kwargs):
        """
        Parameters
        ----------
        data_path : str, Path or list
            the path(s) containing the actual data samples
        load_fn : function
            function to load the actual data
        load_kwargs:
            additional keyword arguments (passed to :param:`load_fn`)
        """
        super().__init__()
        self._load_fn = load_fn
        self._load_kwargs = load_kwargs
        self.data = self._make_dataset(data_path)

    def _make_dataset(self, path: typing.Union[typing.Union[pathlib.Path,
                                                            str], list]
                      ) -> typing.List[dict]:
        """
        Collects the sample paths making up the dataset.

        Parameters
        ----------
        path : str, Path or list
            the path(s) containing the data samples

        Returns
        -------
        list
            the sorted sample paths
        """
        if not isinstance(path, list):
            assert os.path.isdir(path), '%s is not a valid directory' % path
            path = [os.path.join(path, p) for p in os.listdir(path)]

        # BUGFIX: the result of ``sorted(path)`` was previously discarded;
        # sort explicitly for reproducibility since the listdir function
        # does not return the paths in an ordered way on all OS
        return sorted(path)

    def __getitem__(self, index: int) -> dict:
        """
        Making the whole Dataset indexeable. Loads the necessary sample.

        Parameters
        ----------
        index : int
            the integer specifying which sample to load and return

        Returns
        -------
        Any, Dict
            can be any object containing a single sample, but in practice is
            often a dict
        """
        data_dict = self._load_fn(self.data[index],
                                  **self._load_kwargs)
        return data_dict


class IDManager:
    """
    Mixin adding ID-based sample lookup to datasets whose samples store an
    ID under ``id_key``.
    """

    def __init__(self, id_key: str, cache_ids: bool = True, **kwargs):
        """
        Parameters
        ----------
        id_key : str
            key under which each sample stores its ID
        cache_ids : bool
            whether to build an ID -> index mapping once upfront; enables
            fast lookups but iterates the whole dataset once
        **kwargs :
            ignored; accepted for cooperative multiple inheritance
        """
        self.id_key = id_key
        self._cached_ids = None

        if cache_ids:
            self.cache_ids()

    def cache_ids(self):
        """Builds the ID -> index lookup table."""
        self._cached_ids = {
            sample[self.id_key]: idx for idx, sample in enumerate(self)}

    def _find_index_iterative(self, id: str) -> int:
        """Linear-search fallback used when no ID cache was built."""
        for idx, sample in enumerate(self):
            if sample[self.id_key] == id:
                return idx
        raise KeyError(f"ID {id} not found.")

    def get_sample_by_id(self, id: str) -> dict:
        """Returns the sample whose ID equals ``id``."""
        return self[self.get_index_by_id(id)]

    def get_index_by_id(self, id: str) -> int:
        """Returns the index of the sample whose ID equals ``id``."""
        if self._cached_ids is not None:
            return self._cached_ids[id]
        else:
            return self._find_index_iterative(id)


class CacheDatasetID(CacheDataset, IDManager):
    """Cached dataset with additional ID-based sample lookup."""

    def __init__(self, data_path, load_fn, id_key, cache_ids=True,
                 **kwargs):
        super().__init__(data_path, load_fn, **kwargs)
        # check if AbstractDataset did not call IDManager with super
        if not hasattr(self, "id_key"):
            IDManager.__init__(self, id_key, cache_ids=cache_ids)


class LazyDatasetID(LazyDataset, IDManager):
    """Lazily loading dataset with additional ID-based sample lookup."""

    def __init__(self, data_path, load_fn, id_key, cache_ids=True,
                 **kwargs):
        super().__init__(data_path, load_fn, **kwargs)
        # check if AbstractDataset did not call IDManager with super
        if not hasattr(self, "id_key"):
            IDManager.__init__(self, id_key, cache_ids=cache_ids)
# The module-level ``__DEBUG_MODE`` flag defines whether multiprocessing may
# be used. At the moment it is only honoured inside the DataManager, which
# returns either a MultiThreadedAugmenter or a SingleThreadedAugmenter
# depending on the current mode. Every other function using multiprocessing
# should check this flag as well and offer a (potentially much slower)
# single-process fallback.
__DEBUG_MODE = False


def set_debug_mode(mode: bool):
    """
    Sets a new debug mode

    Parameters
    ----------
    mode : bool
        the new debug mode
    """
    global __DEBUG_MODE
    __DEBUG_MODE = mode


def get_current_debug_mode() -> bool:
    """
    Getter function for the current debug mode

    Returns
    -------
    bool
        current debug mode
    """
    return __DEBUG_MODE


def switch_debug_mode():
    """
    Alternates the current debug mode
    """
    set_debug_mode(not get_current_debug_mode())
class DataLoader(_DataLoader):

    def __init__(self, dataset: Union[Sequence, Dataset],
                 batch_size: int = 1, shuffle: bool = False,
                 batch_transforms: Callable = None, sampler: Sampler = None,
                 batch_sampler: Sampler = None, num_workers: int = 0,
                 collate_fn: Callable = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: Union[int, float] = 0,
                 worker_init_fn: Callable = None,
                 multiprocessing_context=None,
                 auto_convert: bool = True):

        """
        A Dataloader introducing batch-transforms, numpy seeds for worker
        processes and compatibility to the debug mode

        Note
        ----
        For Reproducibility numpy and pytorch must be seeded in the main
        process, as these frameworks will be used to generate their own seeds
        for each worker.

        Note
        ----
        ``len(dataloader)`` heuristic is based on the length of the sampler
        used. When :attr:`dataset` is an
        :class:`~torch.utils.data.IterableDataset`, an infinite sampler is
        used, whose :meth:`__len__` is not implemented, because the actual
        length depends on both the iterable as well as multi-process loading
        configurations.

        Warning
        -------
        If the ``spawn`` start method is used, :attr:`worker_init_fn`
        cannot be an unpicklable object, e.g., a lambda function.

        Parameters
        ----------
        dataset : Dataset
            dataset from which to load the data
        batch_size : int, optional
            how many samples per batch to load (default: ``1``).
        shuffle : bool, optional
            set to ``True`` to have the data reshuffled at every epoch
            (default: ``False``)
        batch_transforms : callable, optional
            transforms which can be applied to a whole batch.
            Usually this accepts either mappings or sequences and returns the
            same type containing transformed elements
        sampler : torch.utils.data.Sampler, optional
            defines the strategy to draw samples from
            the dataset. If specified, :attr:`shuffle` must be ``False``.
        batch_sampler : torch.utils.data.Sampler, optional
            like :attr:`sampler`, but returns a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
        num_workers : int, optional
            how many subprocesses to use for data loading.
            ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn : callable, optional
            merges a list of samples to form a
            mini-batch of Tensor(s). Used when using batched loading from a
            map-style dataset.
        pin_memory : bool, optional
            If ``True``, the data loader will copy Tensors
            into CUDA pinned memory before returning them.
        drop_last : bool, optional
            set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size.
            (default: ``False``)
        timeout : numeric, optional
            if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn : callable, optional
            If not ``None``, this will be called on each
            worker subprocess with the worker id
            (an int in ``[0, num_workers - 1]``) as input, after seeding and
            before data loading. (default: ``None``)
        multiprocessing_context : optional
            the multiprocessing context used to spawn worker processes
        auto_convert : bool, optional
            if set to ``True``, the batches will always be transformed to
            torch.Tensors, if possible. (default: ``True``)
        """

        super().__init__(dataset=dataset, batch_size=batch_size,
                         shuffle=shuffle, sampler=sampler,
                         batch_sampler=batch_sampler, num_workers=num_workers,
                         collate_fn=collate_fn, pin_memory=pin_memory,
                         drop_last=drop_last, timeout=timeout,
                         worker_init_fn=worker_init_fn,
                         multiprocessing_context=multiprocessing_context)

        # wrap the (possibly default) collate_fn so batch transforms and the
        # optional tensor conversion are applied after collation
        self.collate_fn = BatchTransformer(self.collate_fn, batch_transforms,
                                           auto_convert)

    def __iter__(self) -> Union["_SingleProcessDataLoaderIter",
                                "_MultiProcessingDataLoaderIter"]:
        # BUGFIX: the return annotation is quoted; the custom
        # _MultiProcessingDataLoaderIter is defined further down the module,
        # so an eagerly evaluated annotation raised a NameError at class
        # creation time (the module's ``from __future__ import annotations``
        # is misplaced - not the first statement - and therefore ineffective).
        if self.num_workers == 0 or get_current_debug_mode():
            # debug mode forces single-process loading
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)
class BatchTransformer(object):
    """
    A callable wrapping the collate_fn to enable transformations on a
    batch-basis.
    """

    def __init__(self, collate_fn: Callable, transforms: Callable = None,
                 auto_convert: bool = True):
        """
        Parameters
        ----------
        collate_fn : callable
            merges a list of samples to form a mini-batch of Tensor(s);
            used when using batched loading from a map-style dataset
        transforms : callable, optional
            transforms applied to the whole collated batch; usually accepts
            either mappings or sequences and returns the same type with
            transformed elements
        auto_convert : bool, optional
            if set to ``True``, the batches will always be transformed to
            torch.Tensors, if possible. (default: ``True``)
        """
        self._collate_fn = collate_fn
        self._transforms = transforms
        self._auto_convert = auto_convert

    def __call__(self, *args, **kwargs) -> Any:
        collated = self._collate_fn(*args, **kwargs)

        transform = self._transforms
        if transform is not None:
            # dispatch on the collated batch type: mappings are unpacked as
            # keyword arguments, sequences as positional arguments
            if isinstance(collated, Mapping):
                collated = transform(**collated)
            elif isinstance(collated, Sequence):
                collated = transform(*collated)
            else:
                collated = transform(collated)

        if self._auto_convert:
            collated = default_convert(collated)

        return collated
class _MultiProcessingDataLoaderIter(__MultiProcessingDataLoaderIter):
    # NOTE [ Numpy Seeds ]
    # This class is a subclass of
    # ``torch.utils.data.dataloader._MultiProcessingDataLoaderIter`` and only
    # adds some additional logic to provide different seeds for numpy in
    # each worker. These seeds are based on a base seed, which itself gets
    # generated by numpy. So to ensure reproducibility, numpy must be seeded
    # in the main process.
    def __init__(self, loader):
        """
        Iterator over the Dataloader. Handles the complete multiprocessing.

        Parameters
        ----------
        loader : DataLoader
            the dataloader instance to iterate over
        """

        try:
            import numpy as np
            # upper bound leaves enough room so that seed + worker_id always
            # stays a valid uint32 numpy seed
            npy_seed = np.random.randint(0,
                                         (2 ** 32) - (1 + loader.num_workers))
        except ImportError:
            # we don't generate a numpy seed here with torch, since we don't
            # need one; if the import fails in the main process it should
            # also fail in child processes
            npy_seed = None

        old_worker_init = loader.worker_init_fn

        # wrap the user-provided init fn so every worker seeds numpy first
        new_worker_init_fn = partial(_seed_npy_before_worker_init,
                                     seed=npy_seed,
                                     worker_init_fn=old_worker_init)
        loader.worker_init_fn = new_worker_init_fn

        # limit BLAS threads while the workers are spawned to avoid
        # oversubscribing the CPU
        with threadpool_limits(limits=1, user_api='blas'):
            super().__init__(loader)

        # reset worker_init_fn once the workers have been started to reset
        # to original state for next epoch
        loader.worker_init_fn = old_worker_init
def _seed_npy_before_worker_init(worker_id: int, seed: int,
                                 worker_init_fn: Callable = None):
    """
    Wrapper around an existing ``worker_init_fn`` which seeds numpy inside
    each worker before the wrapped function runs.

    Parameters
    ----------
    worker_id : int
        the number of the worker
    seed : int32
        the base seed in a range of [0, 2**32 - (1 + ``num_workers``)].
        The range ensures, that the whole seed, which consists of the base
        seed and the ``worker_id``, can still be represented as a unit32,
        as it needs to be for numpy seeding
    worker_init_fn : callable, optional
        will be called with the ``worker_id`` after seeding numpy if it is
        not ``None``

    Returns
    -------
    Any
        the return value of ``worker_init_fn`` (``None`` if absent)
    """
    try:
        import numpy as np
    except ImportError:
        # numpy is unavailable in this worker -> nothing to seed
        pass
    else:
        np.random.seed(seed + worker_id)

    if worker_init_fn is None:
        return None
    return worker_init_fn(worker_id)
class DataContainer:
    """
    Holds a dataset and manages named subsets of it (e.g. train/val/test or
    k folds) created from index-based or csv-based split definitions.
    """

    def __init__(self, dataset: "Dataset", **kwargs):
        self._dataset = dataset
        self._dset = {}
        self._fold = None
        super().__init__(**kwargs)

    def split_by_index(self, split: "SplitType"):
        """Creates one subset per entry in ``split`` (name -> index list)."""
        for name, indices in split.items():
            self._dset[name] = self._dataset.get_subset(indices)

    def kfold_by_index(self, splits: typing.Iterable["SplitType"]):
        """Yields ``self`` once per fold with ``fold`` set accordingly."""
        for fold_idx, single_split in enumerate(splits):
            self.split_by_index(single_split)
            self._fold = fold_idx
            yield self
        self._fold = None

    def split_by_csv(self, path: typing.Union[pathlib.Path, str],
                     index_column: str, **kwargs):
        """Reads a single split definition (first column) from a csv file."""
        frame = pd.read_csv(path, **kwargs).set_index(index_column)
        first_col = list(frame.columns)[0]
        self.split_by_index(self._read_split_from_df(frame, first_col))

    def kfold_by_csv(self, path: typing.Union[pathlib.Path, str],
                     index_column: str, **kwargs) -> "DataContainer":
        """Reads one split per csv column and iterates over the folds."""
        frame = pd.read_csv(path, **kwargs).set_index(index_column)
        splits = [self._read_split_from_df(frame, fold)
                  for fold in frame.columns]
        yield from self.kfold_by_index(splits)

    @staticmethod
    def _read_split_from_df(df: pd.DataFrame, col: str):
        """Builds a split-name -> index-list mapping from one column."""
        split = defaultdict(list)
        for index, value in df[col].items():
            split[str(value)].append(index)
        return split

    @property
    def dset(self) -> "Dataset":
        """Mapping of split names to subsets; raises before any split."""
        if not self._dset:
            raise AttributeError("No Split found.")
        return self._dset

    @property
    def fold(self) -> int:
        """Current fold index; only valid while iterating over folds."""
        if self._fold is None:
            raise AttributeError(
                "Fold not specified. Call `kfold_by_index` first.")
        return self._fold


class DataContainerID(DataContainer):
    """
    DataContainer for datasets providing ``get_index_by_id``; splits can be
    defined via sample IDs instead of indices.
    """

    def split_by_id(self, split: "SplitType"):
        """Translates an ID-based split to indices and applies it."""
        split_idx = {
            name: [self._dataset.get_index_by_id(sample_id)
                   for sample_id in ids]
            for name, ids in split.items()}
        super().split_by_index(split_idx)

    def kfold_by_id(
            self,
            splits: typing.Iterable["SplitType"]) -> "DataContainerID":
        """Yields ``self`` once per ID-based fold."""
        for fold_idx, single_split in enumerate(splits):
            self.split_by_id(single_split)
            self._fold = fold_idx
            yield self
        self._fold = None

    def split_by_csv_id(self, path: typing.Union[pathlib.Path, str],
                        id_column: str, **kwargs):
        """Reads a single ID-based split definition from a csv file."""
        frame = pd.read_csv(path, **kwargs).set_index(id_column)
        first_col = list(frame.columns)[0]
        self.split_by_id(self._read_split_from_df(frame, first_col))

    def kfold_by_csv_id(self, path: typing.Union[pathlib.Path, str],
                        id_column: str, **kwargs) -> "DataContainerID":
        """Reads one ID-based split per csv column and iterates the folds."""
        frame = pd.read_csv(path, **kwargs).set_index(id_column)
        splits = [self._read_split_from_df(frame, fold)
                  for fold in frame.columns]
        yield from self.kfold_by_id(splits)

    def save_split_to_csv_id(self,
                             path: typing.Union[pathlib.Path, str],
                             id_key: str,
                             split_column: str = 'split',
                             **kwargs):
        """Writes the current split assignment (one row per sample) to csv."""
        split_dict = {str(id_key): [], str(split_column): []}
        for name, subset in self._dset.items():
            for sample in subset:
                split_dict[str(id_key)].append(sample[id_key])
                split_dict[str(split_column)].append(str(name))
        pd.DataFrame(split_dict).to_csv(path, **kwargs)
class Splitter:
    """
    Creates train/val(/test) index splits - either a single split or k folds
    over a fixed test set - for a given dataset. ``val_size`` and
    ``test_size`` may be passed as absolute counts or as proportions in
    (0, 1), which are converted to counts on construction.
    """

    def __init__(self,
                 dataset: Dataset,
                 val_size: typing.Union[int, float],
                 test_size: typing.Union[int, float] = None):
        """
        Parameters
        ----------
        dataset : Dataset
            the dataset to generate splits for
        val_size : int or float
            validation set size; absolute count or proportion in (0, 1)
        test_size : int or float, optional
            test set size; absolute count or proportion in (0, 1).
            ``None`` disables the test split.
        """
        super().__init__()
        self._dataset = dataset
        self._total_num = len(self._dataset)
        self._idx = list(range(self._total_num))
        self._val = val_size
        self._test = test_size if test_size is not None else 0

        self._check_sizes()

    def _check_sizes(self):
        """Validates the sizes and converts proportions to counts."""
        if self._total_num < 0:
            raise TypeError("Size must be larger than zero, not "
                            "{}".format(self._total_num))
        if self._val < 0:
            raise TypeError("Size must be larger than zero, not "
                            "{}".format(self._val))
        if self._test < 0:
            raise TypeError("Size must be larger than zero, not "
                            "{}".format(self._test))

        self._convert_prop_to_num()
        if self._total_num < self._val + self._test:
            raise TypeError("Val + test size must be smaller than total, "
                            "not {}".format(self._val + self._test))

    def index_split(self, **kwargs) -> SplitType:
        """Creates a single random train/val(/test) index split."""
        split_dict = {}
        split_dict["train"], tmp = train_test_split(
            self._idx, test_size=self._val + self._test,
            **kwargs)

        if self._test > 0:
            # BUGFIX: carve ``self._test`` test samples out of the remainder;
            # using ``self._val`` here swapped the val and test set sizes
            split_dict["val"], split_dict["test"] = train_test_split(
                tmp, test_size=self._test, **kwargs)
        else:
            split_dict["val"] = tmp
        self.log_split(split_dict, "Created Single Split with:")
        return split_dict

    def index_split_stratified(
            self,
            stratify_key: str = "label",
            **kwargs) -> SplitType:
        """Creates a single split stratified by ``sample[stratify_key]``."""
        split_dict = {}
        stratify = [d[stratify_key] for d in self._dataset]

        split_dict["train"], tmp = train_test_split(
            self._idx, test_size=self._val + self._test, stratify=stratify,
            **kwargs)

        if self._test > 0:
            stratify_tmp = [stratify[_i] for _i in tmp]
            # BUGFIX: test split must have size ``self._test``
            # (see ``index_split``)
            split_dict["val"], split_dict["test"] = train_test_split(
                tmp, test_size=self._test, stratify=stratify_tmp, **kwargs)
        else:
            split_dict["val"] = tmp
        self.log_split(split_dict, "Created Single Split with:")
        return split_dict

    def index_split_grouped(
            self,
            groups_key: str = "id",
            **kwargs) -> SplitType:
        """
        Creates a single split keeping samples with equal
        ``sample[groups_key]`` within the same subset.

        ..warning:: Shuffling cannot be deactivated
        """
        split_dict = {}
        groups = [d[groups_key] for d in self._dataset]

        gsp = GroupShuffleSplit(
            n_splits=1, test_size=self._val + self._test, **kwargs)
        split_dict["train"], tmp = next(gsp.split(self._idx, groups=groups))

        if self._test > 0:
            groups_tmp = [groups[_i] for _i in tmp]
            # BUGFIX: test split must have size ``self._test``
            # (see ``index_split``)
            gsp = GroupShuffleSplit(n_splits=1, test_size=self._test,
                                    **kwargs)
            split_dict["val"], split_dict["test"] = next(
                gsp.split(tmp, groups=groups_tmp))
        else:
            split_dict["val"] = tmp
        self.log_split(split_dict, "Created Single Split with:")
        return split_dict

    def index_kfold_fixed_test(self, **kwargs) -> typing.Iterable[SplitType]:
        """Creates k train/val folds over a fixed test set."""
        splits = []

        idx_dict = self.index_split(**kwargs)
        train_val_idx = idx_dict.pop("train")
        idx_dict.pop("val")

        logger.info("Creating {} folds.".format(self.val_folds))
        kf = KFold(n_splits=self.val_folds, **kwargs)
        _fold = 0
        for train_idx, val_idx in kf.split(train_val_idx):
            splits.append(self._copy_and_fill_dict(
                idx_dict, train=train_idx, val=val_idx))
            self.log_split(splits[-1], f"Created Fold{_fold}.")
            _fold += 1
        return splits

    def index_kfold_fixed_test_stratified(
            self,
            stratify_key: str = "label",
            **kwargs) -> typing.Iterable[SplitType]:
        """Creates k stratified train/val folds over a fixed test set."""
        splits = []

        idx_dict = self.index_split_stratified(**kwargs)
        train_val_idx = idx_dict.pop("train")
        idx_dict.pop("val")
        train_val_stratify = [
            self._dataset[_i][stratify_key] for _i in train_val_idx]

        # BUGFIX: ``self.val_fols`` was a typo raising an AttributeError
        logger.info("Creating {} folds.".format(self.val_folds))
        kf = StratifiedKFold(n_splits=self.val_folds, **kwargs)
        _fold = 0
        for train_idx, val_idx in kf.split(train_val_idx, train_val_stratify):
            splits.append(self._copy_and_fill_dict(
                idx_dict, train=train_idx, val=val_idx))
            self.log_split(splits[-1], f"Created Fold{_fold}.")
            _fold += 1
        return splits

    def index_kfold_fixed_test_grouped(self, groups_key: str = "id",
                                       **kwargs
                                       ) -> typing.Iterable[SplitType]:
        """Creates k grouped train/val folds over a fixed test set."""
        splits = []

        idx_dict = self.index_split_grouped(**kwargs)
        train_val_idx = idx_dict.pop("train")
        idx_dict.pop("val")
        train_val_groups = [
            self._dataset[_i][groups_key] for _i in train_val_idx]

        # BUGFIX: ``self.val_fols`` was a typo raising an AttributeError
        logger.info("Creating {} folds.".format(self.val_folds))
        kf = GroupKFold(n_splits=self.val_folds, **kwargs)
        _fold = 0
        for train_idx, val_idx in kf.split(
                train_val_idx, groups=train_val_groups):
            splits.append(self._copy_and_fill_dict(
                idx_dict, train=train_idx, val=val_idx))
            self.log_split(splits[-1], f"Created Fold{_fold}.")
            _fold += 1
        return splits

    def _convert_prop_to_num(self, attributes: tuple = ("_val", "_test")):
        """Converts proportions in (0, 1) to absolute sample counts."""
        for attr in attributes:
            value = getattr(self, attr)
            # BUGFIX: the old condition (``math.isclose(value, 0)``) only
            # ever converted zeros; proportions are the values in (0, 1).
            # sklearn also requires absolute sizes to be ints.
            if value < 1 and not math.isclose(value, 0):
                setattr(self, attr, int(round(value * self._total_num)))

    @staticmethod
    def log_split(dict_like: dict, desc: str = None):
        """Logs the sizes of all subsets in ``dict_like``."""
        if desc is not None:
            logger.info(desc)
        for key, item in dict_like.items():
            logger.info(f"{str(key).upper()} contains {len(item)} indices.")

    @staticmethod
    def _copy_and_fill_dict(dict_like: dict, **kwargs) -> dict:
        """Deep-copies ``dict_like`` and sets the given keys on the copy."""
        new_dict = copy.deepcopy(dict_like)
        for key, item in kwargs.items():
            new_dict[key] = item
        return new_dict

    @property
    def dataset(self) -> Dataset:
        return self._dataset

    @dataset.setter
    def dataset(self, dset: Dataset):
        self._dataset = dset
        self._total_num = len(self._dataset)
        self._idx = list(range(self._total_num))

    @property
    def val_size(self) -> int:
        return self._val

    @val_size.setter
    def val_size(self, value: typing.Union[int, float]):
        self._val = value
        self._check_sizes()

    @property
    def test_size(self) -> int:
        return self._test

    @test_size.setter
    def test_size(self, value: typing.Union[int, float]):
        self._test = value
        self._check_sizes()

    @property
    def folds(self) -> int:
        # total number of (test fold, val fold) combinations
        return self.val_folds * self.test_folds

    @property
    def val_folds(self) -> int:
        # number of val folds of size ``_val`` fitting into the dataset
        return int(self._total_num // self._val)

    @property
    def test_folds(self) -> int:
        # number of test folds of size ``_test`` fitting into the dataset
        return int(self._total_num // self._test)
.../{test_transforms => transforms}/test_abstract_transform.py | 0 tests/{test_transforms => transforms}/test_format_transforms.py | 0 tests/transforms/test_functional/__init__.py | 0 .../test_functional/test_intensity.py | 0 .../test_functional/test_spatial.py | 0 .../test_intensity_transforms.py | 2 +- tests/{test_transforms => transforms}/test_kernel_transforms.py | 0 .../{test_transforms => transforms}/test_spatial_transforms.py | 2 +- .../{test_transforms => transforms}/test_utility_transforms.py | 0 13 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 tests/ops/__init__.py rename tests/{test_ops => ops}/test_tensor.py (100%) rename tests/{test_transforms => transforms}/__init__.py (100%) rename tests/{test_transforms => transforms}/_utils.py (100%) rename tests/{test_transforms => transforms}/test_abstract_transform.py (100%) rename tests/{test_transforms => transforms}/test_format_transforms.py (100%) create mode 100644 tests/transforms/test_functional/__init__.py rename tests/{test_transforms => transforms}/test_functional/test_intensity.py (100%) rename tests/{test_transforms => transforms}/test_functional/test_spatial.py (100%) rename tests/{test_transforms => transforms}/test_intensity_transforms.py (99%) rename tests/{test_transforms => transforms}/test_kernel_transforms.py (100%) rename tests/{test_transforms => transforms}/test_spatial_transforms.py (96%) rename tests/{test_transforms => transforms}/test_utility_transforms.py (100%) diff --git a/tests/ops/__init__.py b/tests/ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_ops/test_tensor.py b/tests/ops/test_tensor.py similarity index 100% rename from tests/test_ops/test_tensor.py rename to tests/ops/test_tensor.py diff --git a/tests/test_transforms/__init__.py b/tests/transforms/__init__.py similarity index 100% rename from tests/test_transforms/__init__.py rename to tests/transforms/__init__.py diff --git a/tests/test_transforms/_utils.py 
b/tests/transforms/_utils.py similarity index 100% rename from tests/test_transforms/_utils.py rename to tests/transforms/_utils.py diff --git a/tests/test_transforms/test_abstract_transform.py b/tests/transforms/test_abstract_transform.py similarity index 100% rename from tests/test_transforms/test_abstract_transform.py rename to tests/transforms/test_abstract_transform.py diff --git a/tests/test_transforms/test_format_transforms.py b/tests/transforms/test_format_transforms.py similarity index 100% rename from tests/test_transforms/test_format_transforms.py rename to tests/transforms/test_format_transforms.py diff --git a/tests/transforms/test_functional/__init__.py b/tests/transforms/test_functional/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_transforms/test_functional/test_intensity.py b/tests/transforms/test_functional/test_intensity.py similarity index 100% rename from tests/test_transforms/test_functional/test_intensity.py rename to tests/transforms/test_functional/test_intensity.py diff --git a/tests/test_transforms/test_functional/test_spatial.py b/tests/transforms/test_functional/test_spatial.py similarity index 100% rename from tests/test_transforms/test_functional/test_spatial.py rename to tests/transforms/test_functional/test_spatial.py diff --git a/tests/test_transforms/test_intensity_transforms.py b/tests/transforms/test_intensity_transforms.py similarity index 99% rename from tests/test_transforms/test_intensity_transforms.py rename to tests/transforms/test_intensity_transforms.py index 66e5026d..7063d7cc 100644 --- a/tests/test_transforms/test_intensity_transforms.py +++ b/tests/transforms/test_intensity_transforms.py @@ -4,7 +4,7 @@ from math import isclose from unittest.mock import Mock, call -from tests.test_transforms import chech_data_preservation +from tests.transforms import chech_data_preservation from rising.transforms.intensity import * diff --git a/tests/test_transforms/test_kernel_transforms.py 
b/tests/transforms/test_kernel_transforms.py similarity index 100% rename from tests/test_transforms/test_kernel_transforms.py rename to tests/transforms/test_kernel_transforms.py diff --git a/tests/test_transforms/test_spatial_transforms.py b/tests/transforms/test_spatial_transforms.py similarity index 96% rename from tests/test_transforms/test_spatial_transforms.py rename to tests/transforms/test_spatial_transforms.py index 60c9489a..1b5737e9 100644 --- a/tests/test_transforms/test_spatial_transforms.py +++ b/tests/transforms/test_spatial_transforms.py @@ -2,7 +2,7 @@ import random import unittest -from tests.test_transforms import chech_data_preservation +from tests.transforms import chech_data_preservation from rising.transforms.spatial import * diff --git a/tests/test_transforms/test_utility_transforms.py b/tests/transforms/test_utility_transforms.py similarity index 100% rename from tests/test_transforms/test_utility_transforms.py rename to tests/transforms/test_utility_transforms.py From 4416632fed82f39d56548863ca1decd0a9e4999e Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Tue, 26 Nov 2019 18:03:34 +0100 Subject: [PATCH 12/39] add _utils for testing --- tests/_utils.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/_utils.py diff --git a/tests/_utils.py b/tests/_utils.py new file mode 100644 index 00000000..ffa342a0 --- /dev/null +++ b/tests/_utils.py @@ -0,0 +1,30 @@ +import numpy as np + +from rising.loading.dataset import CacheDataset, CacheDatasetID + + +class LoadDummySample: + def __init__(self, keys=('data', 'label'), sizes=((3, 128, 128), (3,)), + **kwargs): + super().__init__(**kwargs) + self.keys = keys + self.sizes = sizes + + def __call__(self, path, *args, **kwargs): + data = {_k: np.random.rand(*_s) + for _k, _s in zip(self.keys, self.sizes)} + data['id'] = f'sample{path}' + return data + + +class DummyDataset(CacheDataset): + def __init__(self, num_samples=10, load_fn=LoadDummySample(), + 
**load_kwargs): + super().__init__(list(range(num_samples)), load_fn, **load_kwargs) + + +class DummyDatasetID(CacheDatasetID): + def __init__(self, num_samples=10, load_fn=LoadDummySample(), + **load_kwargs): + super().__init__(list(range(num_samples)), load_fn, id_key="id", + **load_kwargs) From ca7fde2288ae9ec493ea580d52c2ac2e0c808d73 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Tue, 26 Nov 2019 18:04:01 +0100 Subject: [PATCH 13/39] rename test directory --- .../test_functional => transforms/functional}/__init__.py | 0 .../transforms/{test_functional => functional}/test_intensity.py | 0 tests/transforms/{test_functional => functional}/test_spatial.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename tests/{test_transforms/test_functional => transforms/functional}/__init__.py (100%) rename tests/transforms/{test_functional => functional}/test_intensity.py (100%) rename tests/transforms/{test_functional => functional}/test_spatial.py (100%) diff --git a/tests/test_transforms/test_functional/__init__.py b/tests/transforms/functional/__init__.py similarity index 100% rename from tests/test_transforms/test_functional/__init__.py rename to tests/transforms/functional/__init__.py diff --git a/tests/transforms/test_functional/test_intensity.py b/tests/transforms/functional/test_intensity.py similarity index 100% rename from tests/transforms/test_functional/test_intensity.py rename to tests/transforms/functional/test_intensity.py diff --git a/tests/transforms/test_functional/test_spatial.py b/tests/transforms/functional/test_spatial.py similarity index 100% rename from tests/transforms/test_functional/test_spatial.py rename to tests/transforms/functional/test_spatial.py From d8e35c457aae1179096c4449de21e7e460191dae Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Tue, 26 Nov 2019 18:41:44 +0100 Subject: [PATCH 14/39] pep8 changes --- rising/loading/dataset.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/rising/loading/dataset.py 
b/rising/loading/dataset.py index 7e497ef1..3980ca59 100644 --- a/rising/loading/dataset.py +++ b/rising/loading/dataset.py @@ -11,14 +11,11 @@ from torch.multiprocessing import Pool - class Dataset(TorchDset): """ Extension of PyTorch's Datasets by a ``get_subset`` method which returns a sub-dataset. """ - - # TODO: Add return type signature def get_subset(self, indices: typing.Sequence[int]) -> SubsetDataset: """ Returns a Subset of the current dataset based on given indices From 1599a2848f810c1d4ed3eb495bcbc4af59b66fca Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Tue, 26 Nov 2019 18:42:27 +0100 Subject: [PATCH 15/39] add dummy testcase for numpy_collate --- tests/loading/test_collate.py | 37 +++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tests/loading/test_collate.py diff --git a/tests/loading/test_collate.py b/tests/loading/test_collate.py new file mode 100644 index 00000000..cdf47ee5 --- /dev/null +++ b/tests/loading/test_collate.py @@ -0,0 +1,37 @@ +import torch +import unittest + +try: + import numpy as np +except ImportError: + np = None + +from rising.loading.collate import numpy_collate + + +# TODO: Add more collate test cases +class TestCollate(unittest.TestCase): + @unittest.skipIf(np is None, 'numpy is not available') + def test_default_collate_dtype(self): + arr = [1, 2, -1] + collated = numpy_collate(arr) + self.assertEqual(collated, np.array(arr)) + self.assertEqual(collated.dtype, np.int32) + + arr = [1.1, 2.3, -0.9] + collated = numpy_collate(arr) + self.assertEqual(collated, np.array(arr)) + self.assertEqual(collated.dtype, np.float32) + + arr = [True, False] + collated = numpy_collate(arr) + self.assertEqual(collated, np.array(arr)) + self.assertEqual(collated.dtype, np.bool) + + # Should be a no-op + arr = ['a', 'b', 'c'] + self.assertEqual(arr, numpy_collate(arr)) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From f1096edcf76ae553f7c332228990a3ca02ebae6b Mon Sep 17 
00:00:00 2001 From: Justus Schock Date: Tue, 26 Nov 2019 18:43:05 +0100 Subject: [PATCH 16/39] add container tests --- tests/loading/test_container.py | 131 ++++++++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 tests/loading/test_container.py diff --git a/tests/loading/test_container.py b/tests/loading/test_container.py new file mode 100644 index 00000000..9259918b --- /dev/null +++ b/tests/loading/test_container.py @@ -0,0 +1,131 @@ +import unittest +import os +from pathlib import Path +import numpy as np +import pandas as pd +from rising.loading.container import DataContainer, DataContainerID +from tests import DummyDataset, DummyDatasetID + + +class LoadDummySampleID: + def __init__(self, keys=('data', 'label'), sizes=((3, 128, 128), (3,)), + **kwargs): + super().__init__(**kwargs) + self.keys = keys + self.sizes = sizes + + def __call__(self, path, *args, **kwargs): + data = {_k: np.random.rand(*_s) + for _k, _s in zip(self.keys, self.sizes)} + data['id'] = int(path) + return data + + +class TestDataContainer(unittest.TestCase): + def setUp(self): + self.dset = DummyDataset(num_samples=6, + load_fn=(LoadDummySampleID())) + self.dset_id = DummyDatasetID(num_samples=6, + load_fn=LoadDummySampleID()) + self.split = {"train": [0, 1, 2], "val": [3, 4, 5]} + self.kfold = [{"train": [0, 1, 2], "val": [3, 4, 5]}, + {"val": [0, 1, 2], "train": [3, 4, 5]}] + + def test_empty_container(self): + container = DataContainer(self.dset) + with self.assertRaises(AttributeError): + container.dset["train"] + + with self.assertRaises(AttributeError): + container.fold + + def test_split_by_index(self): + container = DataContainer(self.dset) + container.split_by_index(self.split) + self.check_split(container) + + def check_split(self, container): + self._assert_split(container, [0, 1, 2], [3, 4, 5]) + + def test_kfold_by_index(self): + container = DataContainer(self.dset) + self.check_kfold(container.kfold_by_index(self.kfold)) + + def 
check_kfold(self, container_generator): + for container_fold in container_generator: + if container_fold.fold == 0: + self._assert_split(container_fold, [0, 1, 2], [3, 4, 5]) + elif container_fold.fold == 1: + self._assert_split(container_fold, [3, 4, 5], [0, 1, 2]) + else: + self.assertTrue(False, "Unknown Fold") + + def test_split_by_csv(self): + container = DataContainer(self.dset) + p = os.path.join(os.path.dirname(__file__), '_src', 'split.csv') + container.split_by_csv(p, 'index', sep=';') + self.check_split_csv(container) + + def check_split_csv(self, container): + self._assert_split(container, [0, 1], [2, 3], [4, 5]) + + def test_kfold_by_csv(self): + container = DataContainer(self.dset) + p = os.path.join(os.path.dirname(__file__), '_src', 'kfold.csv') + self.check_kfold_csv(container.kfold_by_csv(p, 'index', sep=';')) + + def check_kfold_csv(self, container_generator): + for container_fold in container_generator: + if container_fold.fold == 0: + self._assert_split(container_fold, [0, 1], [2, 3], [4, 5]) + elif container_fold.fold == 1: + self._assert_split(container_fold, [2, 3], [4, 5], [0, 1]) + elif container_fold.fold == 2: + self._assert_split(container_fold, [4, 5], [0, 1], [2, 3]) + else: + self.assertTrue(False, "Unknown Fold") + + def _assert_split(self, container, train, val=None, test=None): + self.assertEqual([d["id"] for d in container.dset["train"]], train) + if val is not None: + self.assertEqual([d["id"] for d in container.dset["val"]], val) + if test is not None: + self.assertEqual([d["id"] for d in container.dset["test"]], test) + + def test_split_by_index_id(self): + container = DataContainerID(self.dset_id) + container.split_by_id(self.split) + self.check_split(container) + + def test_kfold_by_index_id(self): + container = DataContainerID(self.dset_id) + self.check_kfold(container.kfold_by_id(self.kfold)) + + def test_split_by_csv_id(self): + container = DataContainerID(self.dset_id) + p = os.path.join(os.path.dirname(__file__), 
'_src', 'split.csv') + container.split_by_csv_id(p, 'index', sep=';') + self.check_split_csv(container) + + def test_kfold_by_csv_id(self): + container = DataContainerID(self.dset_id) + p = os.path.join(os.path.dirname(__file__), '_src', 'kfold.csv') + self.check_kfold_csv(container.kfold_by_csv_id(p, 'index', sep=';')) + + def test_save_split_to_csv_id(self): + container = DataContainerID(self.dset_id) + container.split_by_id(self.split) + p = os.path.join(os.path.dirname(__file__), + "_src", "test_generated_split.csv") + + container.save_split_to_csv_id(p, "id") + + df = pd.read_csv(p) + self.assertTrue(df["id"].to_list(), [1, 2, 3, 4, 5, 6]) + self.assertTrue(df["split"].to_list(), + ["train", "train", "train", "val", "val", "val"]) + os.remove(p) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 197c4d756d68c71f915dd2d87cd5f88b71113f68 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Tue, 26 Nov 2019 18:43:17 +0100 Subject: [PATCH 17/39] add dataset tests --- tests/loading/test_dataset.py | 134 ++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 tests/loading/test_dataset.py diff --git a/tests/loading/test_dataset.py b/tests/loading/test_dataset.py new file mode 100644 index 00000000..76363d6e --- /dev/null +++ b/tests/loading/test_dataset.py @@ -0,0 +1,134 @@ +import unittest + +import numpy as np +from rising.loading.dataset import CacheDataset, LazyDataset, CacheDatasetID, \ + LazyDatasetID + + +# TODO: Additional Tests for subsetdataset + +class LoadDummySample: + def __call__(self, path, *args, **kwargs): + data = {'data': np.random.rand(1, 256, 256), + 'label': np.random.randint(2), + 'id': f"sample{path}"} + return data + + +class TestBaseDataset(unittest.TestCase): + def setUp(self): + self.paths = list(range(10)) + + def test_cache_dataset(self): + dataset = CacheDataset(self.paths, LoadDummySample(), + label_load_fct=None) + self.assertEqual(len(dataset), 10) + 
self.check_dataset_access(dataset, [0, 5, 9]) + self.check_dataset_outside_access(dataset, [10, 20]) + self.check_dataset_iter(dataset) + + def test_cache_dataset_extend(self): + def load_mul_sample(path) -> list: + return [LoadDummySample()(path, None)] * 4 + + dataset = CacheDataset(self.paths, load_mul_sample, + mode='extend') + self.assertEqual(len(dataset), 40) + self.check_dataset_access(dataset, [0, 20, 39]) + self.check_dataset_outside_access(dataset, [40, 45]) + self.check_dataset_iter(dataset) + + def test_lazy_dataset(self): + dataset = LazyDataset(self.paths, LoadDummySample(), + label_load_fct=None) + self.assertEqual(len(dataset), 10) + self.check_dataset_access(dataset, [0, 5, 9]) + self.check_dataset_outside_access(dataset, [10, 20]) + self.check_dataset_iter(dataset) + + def check_dataset_access(self, dataset, inside_idx): + try: + for _i in inside_idx: + a = dataset[_i] + except BaseException: + self.assertTrue(False) + + def check_dataset_outside_access(self, dataset, outside_idx): + for _i in outside_idx: + with self.assertRaises(IndexError): + a = dataset[_i] + + def check_dataset_iter(self, dataset): + try: + j = 0 + for i in dataset: + self.assertIn('data', i) + self.assertIn('label', i) + j += 1 + assert j == len(dataset) + except BaseException: + raise AssertionError('Dataset iteration failed.') + + +class TestDatasetID(unittest.TestCase): + def test_load_dummy_sample(self): + load_fn = LoadDummySample() + sample0 = load_fn(None, None) + self.assertIn("data", sample0) + self.assertIn("label", sample0) + self.assertTrue("id", "sample0") + + sample1 = load_fn(None, None) + self.assertIn("data", sample1) + self.assertIn("label", sample1) + self.assertTrue("id", "sample1") + + def check_dataset( + self, + dset_cls, + num_samples, + expected_len, + debug_num, + **kwargs): + load_fn = LoadDummySample() + dset = dset_cls(list(range(num_samples)), load_fn, + debug_num=debug_num, **kwargs) + self.assertEqual(len(dset), expected_len) + + def 
test_base_cache_dataset(self): + self.check_dataset(CacheDataset, num_samples=20, + expected_len=20, debug_num=10) + + def test_base_lazy_dataset_debug_off(self): + self.check_dataset(LazyDataset, num_samples=20, + expected_len=20, debug_num=10) + + def test_cachedataset_id(self): + load_fn = LoadDummySample() + dset = CacheDatasetID(list(range(10)), load_fn, + id_key="id", cache_ids=False) + self.check_dset_id(dset) + + def test_lazydataset_id(self): + load_fn = LoadDummySample() + dset = LazyDatasetID(list(range(10)), load_fn, + id_key="id", cache_ids=False) + self.check_dset_id(dset) + + def check_dset_id(self, dset): + idx = dset.get_index_by_id("sample1") + self.assertTrue(idx, 1) + + with self.assertRaises(KeyError): + idx = dset.get_index_by_id("sample10") + + sample5 = dset.get_sample_by_id("sample5") + self.assertTrue(sample5["id"], 5) + + dset.cache_ids() + sample6 = dset.get_sample_by_id("sample6") + self.assertTrue(sample6["id"], 6) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From a046704b95093ebf80435ab62a6c1ae9472048ac Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Tue, 26 Nov 2019 18:43:47 +0100 Subject: [PATCH 18/39] Add dummy files --- tests/{test_ops => loading}/__init__.py | 0 tests/loading/test_debug_mode.py | 1 + tests/loading/test_loader.py | 1 + tests/transforms/test_functional/__init__.py | 0 4 files changed, 2 insertions(+) rename tests/{test_ops => loading}/__init__.py (100%) create mode 100644 tests/loading/test_debug_mode.py create mode 100644 tests/loading/test_loader.py delete mode 100644 tests/transforms/test_functional/__init__.py diff --git a/tests/test_ops/__init__.py b/tests/loading/__init__.py similarity index 100% rename from tests/test_ops/__init__.py rename to tests/loading/__init__.py diff --git a/tests/loading/test_debug_mode.py b/tests/loading/test_debug_mode.py new file mode 100644 index 00000000..8051fd04 --- /dev/null +++ b/tests/loading/test_debug_mode.py @@ -0,0 +1 @@ +# TODO: Add 
tests for debug mode \ No newline at end of file diff --git a/tests/loading/test_loader.py b/tests/loading/test_loader.py new file mode 100644 index 00000000..a71ee5bc --- /dev/null +++ b/tests/loading/test_loader.py @@ -0,0 +1 @@ +# TODO: Add Loader Tests \ No newline at end of file diff --git a/tests/transforms/test_functional/__init__.py b/tests/transforms/test_functional/__init__.py deleted file mode 100644 index e69de29b..00000000 From 05d1e3a42f1a5dab8cde1e65c937abc83311a90d Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Tue, 26 Nov 2019 18:44:55 +0100 Subject: [PATCH 19/39] Export Basic API (should we export more?) --- rising/loading/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rising/loading/__init__.py b/rising/loading/__init__.py index e69de29b..44a3deaa 100644 --- a/rising/loading/__init__.py +++ b/rising/loading/__init__.py @@ -0,0 +1,3 @@ +from rising.loading.collate import numpy_collate +from rising.loading.dataset import Dataset +from rising.loading.loader import DataLoader From 5943a186f393eea081b1f6704e02d3ce95a56d2b Mon Sep 17 00:00:00 2001 From: Michael Baumgartner Date: Tue, 26 Nov 2019 17:46:37 +0000 Subject: [PATCH 20/39] autopep8 fix --- rising/loading/dataset.py | 3 ++- rising/loading/loader.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rising/loading/dataset.py b/rising/loading/dataset.py index 3980ca59..00c57682 100644 --- a/rising/loading/dataset.py +++ b/rising/loading/dataset.py @@ -16,6 +16,7 @@ class Dataset(TorchDset): Extension of PyTorch's Datasets by a ``get_subset`` method which returns a sub-dataset. 
""" + def get_subset(self, indices: typing.Sequence[int]) -> SubsetDataset: """ Returns a Subset of the current dataset based on given indices @@ -361,4 +362,4 @@ def __init__(self, data_path, load_fn, id_key, cache_ids=True, super().__init__(data_path, load_fn, **kwargs) # check if AbstractDataset did not call IDManager with super if not hasattr(self, "id_key"): - IDManager.__init__(self, id_key, cache_ids=cache_ids) \ No newline at end of file + IDManager.__init__(self, id_key, cache_ids=cache_ids) diff --git a/rising/loading/loader.py b/rising/loading/loader.py index 24cdec2a..0a289cb1 100644 --- a/rising/loading/loader.py +++ b/rising/loading/loader.py @@ -23,7 +23,6 @@ def __init__(self, dataset: Union[Sequence, Dataset], worker_init_fn: Callable = None, multiprocessing_context=None, auto_convert: bool = True): - """ A Dataloader introducing batch-transforms, numpy seeds for worker processes and compatibility to the debug mode From 6d3e5d96abddf70288f43a15da619dfddca76a48 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Tue, 26 Nov 2019 18:48:11 +0100 Subject: [PATCH 21/39] Add imports in tests package --- tests/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..ac3ff41f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +from ._utils import LoadDummySample, DummyDataset, DummyDatasetID \ No newline at end of file From b95ee1c07dea7efeae68609ddc9b04c03e7f5423 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Thu, 28 Nov 2019 11:14:58 +0100 Subject: [PATCH 22/39] document id managers --- rising/loading/dataset.py | 103 +++++++++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 2 deletions(-) diff --git a/rising/loading/dataset.py b/rising/loading/dataset.py index 00c57682..838693af 100644 --- a/rising/loading/dataset.py +++ b/rising/loading/dataset.py @@ -315,11 +315,23 @@ def __getitem__(self, index: int) -> dict: return data_dict -# TODO: Document IDManagers 
+# TODO: Maybe we should add the dataset baseclass as baseclass of this as well +# (since it should just extend it and still have all the other dataset +# functionalities)? + class IDManager: def __init__(self, id_key: str, cache_ids: bool = True, **kwargs): """ Helper class to add additional functionality to Datasets + + Parameters + ---------- + id_key : str + the id key to cache + cache_ids : bool + whether to cache the ids + **kwargs : + additional keyword arguments """ self.id_key = id_key self._cached_ids = None @@ -327,20 +339,71 @@ def __init__(self, id_key: str, cache_ids: bool = True, **kwargs): if cache_ids: self.cache_ids() - def cache_ids(self): + def cache_ids(self) -> None: + """ + Caches the IDs + + """ self._cached_ids = { sample[self.id_key]: idx for idx, sample in enumerate(self)} def _find_index_iterative(self, id: str) -> int: + """ + Checks for the next index matching the given id + + Parameters + ---------- + id : str + the id to get the index for + + Returns + ------- + int + the returned index + + Raises + ------ + KeyError + no index matching the given id + + """ for idx, sample in enumerate(self): if sample[self.id_key] == id: return idx raise KeyError(f"ID {id} not found.") def get_sample_by_id(self, id: str) -> dict: + """ + Fetches the sample to a corresponding ID + + Parameters + ---------- + id : str + the id specifying the sample to return + + Returns + ------- + dict + the sample corresponding to the given ID + + """ return self[self.get_index_by_id(id)] def get_index_by_id(self, id: str) -> int: + """ + Returns the index corresponding to a given id + + Parameters + ---------- + id : str + the id specifying the index of which sample should be returned + + Returns + ------- + int + the index of the sample matching the given id + + """ if self._cached_ids is not None: return self._cached_ids[id] else: @@ -350,6 +413,24 @@ def get_index_by_id(self, id: str) -> int: class CacheDatasetID(CacheDataset, IDManager): def __init__(self, 
data_path, load_fn, id_key, cache_ids=True, **kwargs): + """ + Caching version of ID Dataset + + Parameters + ---------- + data_path : str, Path or list + the path(s) containing the actual data samples + load_fn : function + function to load the actual data + id_key : str + the id key to cache + cache_ids : bool + whether to cache the ids + **kwargs : + additional keyword arguments + """ + # TODO: Shouldn't we call the baseclasses explicitly here? with super + # it is not clear, which baseclass is actually called super().__init__(data_path, load_fn, **kwargs) # check if AbstractDataset did not call IDManager with super if not hasattr(self, "id_key"): @@ -359,6 +440,24 @@ def __init__(self, data_path, load_fn, id_key, cache_ids=True, class LazyDatasetID(LazyDataset, IDManager): def __init__(self, data_path, load_fn, id_key, cache_ids=True, **kwargs): + """ + Lazy version of ID Dataset + + Parameters + ---------- + data_path : str, Path or list + the path(s) containing the actual data samples + load_fn : function + function to load the actual data + id_key : str + the id key to cache + cache_ids : bool + whether to cache the ids + **kwargs : + additional keyword arguments + """ + # TODO: Shouldn't we call the baseclasses explicitly here? 
with super + # it is not clear, which baseclass is actually called super().__init__(data_path, load_fn, **kwargs) # check if AbstractDataset did not call IDManager with super if not hasattr(self, "id_key"): From ba5017efb8b2217d389a5d163a6e9a68353efa8b Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Thu, 28 Nov 2019 11:15:40 +0100 Subject: [PATCH 23/39] Move future imports to top of file --- rising/loading/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rising/loading/dataset.py b/rising/loading/dataset.py index 838693af..5cb71a03 100644 --- a/rising/loading/dataset.py +++ b/rising/loading/dataset.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import os import typing import pathlib from functools import partial from tqdm import tqdm -from __future__ import annotations import warnings from torch.utils.data import Dataset as TorchDset From 50153e7a398290e87baf66cb0c29659266340f52 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Thu, 28 Nov 2019 11:57:48 +0100 Subject: [PATCH 24/39] Add docstrings, todos and comments for splitter --- rising/loading/splitter.py | 242 ++++++++++++++++++++++++++++++++----- 1 file changed, 213 insertions(+), 29 deletions(-) diff --git a/rising/loading/splitter.py b/rising/loading/splitter.py index 9197f5e4..efe02a7f 100644 --- a/rising/loading/splitter.py +++ b/rising/loading/splitter.py @@ -12,12 +12,36 @@ SplitType = typing.Dict[str, list] -# TODO: Add docstrings for Splitter +# TODO: I would probably change this and make val_size optionally. 
+# We always need a testset, but not always a validationset class Splitter: def __init__(self, dataset: Dataset, val_size: typing.Union[int, float], test_size: typing.Union[int, float] = None): + """ + Splits a dataset by several options + + Parameters + ---------- + dataset : Dataset + the dataset to split + val_size : float, int + the validation split; + if float this will be interpreted as a percentage of the + dataset + if int this will be interpreted as the number of samples + test_size : float, int , optionally + the size of the validation split; If provided it must be int or + float. + if float this will be interpreted as a percentage of the + dataset + if int this will be interpreted as the number of samples + if not provided or explicitly set to None, no testset will be + created + """ + # TODO: Since we only have object as implicit baseclass the + # super().__init__() can probably be removed super().__init__() self._dataset = dataset self._total_num = len(self._dataset) @@ -28,28 +52,62 @@ def __init__(self, self._check_sizes() def _check_sizes(self): - if self._total_num < 0: - raise TypeError("Size must be larger than zero, not " - "{}".format(self._total_num)) - if self._val < 0: - raise TypeError("Size must be larger than zero, not " - "{}".format(self._val)) - if self._test < 0: - raise TypeError("Size must be larger than zero, not " - "{}".format(self._test)) + """ + Checks if the given sizes are valid for splitting + + Raises + ------ + ValueError + at least one of the sizes is invalid + """ + if self._total_num <= 0: + raise ValueError("Size must be larger than zero, not " + "{}".format(self._total_num)) + if self._val <= 0: + raise ValueError("Size must be larger than zero, not " + "{}".format(self._val)) + if self._test < 0: + raise ValueError("Size must be larger than zero, not " + "{}".format(self._test)) + + # TODO: Can we explicitly call this check in the __init__ before + # checking the sizes? 
Explicit is better than implicit + # When I first checked the Code I was wondering where this was done + # and I could not find it, since this function should only check + # and not convert anything self._convert_prop_to_num() if self._total_num < self._val + self._test: - raise TypeError("Val + test size must be smaller than total, " - "not {}".format(self._val + self._test)) + raise ValueError("Val + test size must be smaller than total, " + "not {}".format(self._val + self._test)) def index_split(self, **kwargs) -> SplitType: + """ + Splits the dataset's indices in a random way + + Parameters + ---------- + **kwargs : + optional keyword arguments. + See :func:`sklearn.model_selection.train_test_split` for details + + Returns + ------- + dict + the dictionary containing the corresponding splits under the + keys 'train', 'val' and (optionally) 'test' + + """ split_dict = {} split_dict["train"], tmp = train_test_split( self._idx, test_size=self._val + self._test, **kwargs) if self._test > 0: + # update stratified if provided, + # necessary for index_split_stratified + if 'stratify' in kwargs: + kwargs['stratify'] = [kwargs['stratify'][_i] for _i in tmp] split_dict["val"], split_dict["test"] = train_test_split( tmp, test_size=self._val, **kwargs) else: @@ -61,28 +119,59 @@ def index_split_stratified( self, stratify_key: str = "label", **kwargs) -> SplitType: - split_dict = {} - stratify = [d[stratify_key] for d in self._dataset] + """ + Splits the dataset's indices in a stratified way + + Parameters + ---------- + stratify_key : str + the key specifying which value of each sample to use for + stratification + **kwargs : + optional keyword arguments. 
+ See :func:`sklearn.model_selection.train_test_split` for details + + Returns + ------- + dict + the dictionary containing the corresponding splits under the + keys 'train', 'val' and (optionally) 'test' - split_dict["train"], tmp = train_test_split( - self._idx, test_size=self._val + self._test, stratify=stratify, **kwargs) + """ + stratify = [d[stratify_key] for d in self._dataset] - if self._test > 0: - stratify_tmp = [stratify[_i] for _i in tmp] - split_dict["val"], split_dict["test"] = train_test_split( - tmp, test_size=self._val, stratify=stratify_tmp, **kwargs) - else: - split_dict["val"] = tmp - self.log_split(split_dict, "Created Single Split with:") - return split_dict + return self.index_split(stratify=stratify, **kwargs) def index_split_grouped( self, groups_key: str = "id", **kwargs) -> SplitType: """ - ..warning:: Shuffling cannot be deactivated + Splits the dataset's indices in a stratified way + + Parameters + ---------- + groups_key : str + the key specifying which value of each sample to use for + grouping + **kwargs : + optional keyword arguments. + See :func:`sklearn.model_selection.train_test_split` for details + + Returns + ------- + dict + the dictionary containing the corresponding splits under the + keys 'train', 'val' and (optionally) 'test' + + Warnings + -------- + Shuffling cannot be deactivated """ + # TODO: maybe we should implement a single split function, which + # handles random, stratificated and grouped splitting? This would not + # be hard at all (based on the first look at sklearn internals) and + # would remove some code duolication in here split_dict = {} groups = [d[groups_key] for d in self._dataset] @@ -100,7 +189,24 @@ def index_split_grouped( self.log_split(split_dict, "Created Single Split with:") return split_dict + # TODO: Maybe add kfolds without fixed testset? 
def index_kfold_fixed_test(self, **kwargs) -> typing.Iterable[SplitType]: + """ + Calculates splits for a random kfold with given testset + + Parameters + ---------- + **kwargs : + optional keyword arguments. + See :func:`sklearn.model_selection.train_test_split` for details + + Returns + ------- + list + list containing one dict for each fold each containing the + corresponding splits under the keys 'train', 'val' and 'test' + + """ splits = [] idx_dict = self.index_split(**kwargs) @@ -120,6 +226,25 @@ def index_kfold_fixed_test_stratified( self, stratify_key: str = "label", **kwargs) -> typing.Iterable[SplitType]: + """ + Calculates splits for a stratified kfold with given testset + + Parameters + ---------- + stratify_key : str + the key specifying which value of each sample to use for + stratification + **kwargs : + optional keyword arguments. + See :func:`sklearn.model_selection.train_test_split` for details + + Returns + ------- + list + list containing one dict for each fold each containing the + corresponding splits under the keys 'train', 'val' and 'test' + + """ splits = [] idx_dict = self.index_split_stratified(**kwargs) @@ -139,6 +264,25 @@ def index_kfold_fixed_test_stratified( def index_kfold_fixed_test_grouped(self, groups_key: str = "id", **kwargs) -> typing.Iterable[SplitType]: + """ + Calculates splits for a stratified kfold with given testset + + Parameters + ---------- + groups_key : str + the key specifying which value of each sample to use for + grouping + **kwargs : + optional keyword arguments. 
+ See :func:`sklearn.model_selection.train_test_split` for details + + Returns + ------- + list + list containing one dict for each fold each containing the + corresponding splits under the keys 'train', 'val' and 'test' + + """ splits = [] idx_dict = self.index_split_grouped(**kwargs) @@ -157,14 +301,39 @@ def index_kfold_fixed_test_grouped(self, groups_key: str = "id", _fold += 1 return splits - def _convert_prop_to_num(self, attributes: tuple = ("_val", "_test")): + def _convert_prop_to_num(self, attributes: tuple = ("_val", "_test") + ) -> None: + """ + Converts all given attributes from percentages to number of samples + if necessary + + Parameters + ---------- + attributes : tuple + tuple of strings containing the attribute names + + """ for attr in attributes: value = getattr(self, attr) + # TODO: When is a value close to 0? + # Shouldn't we only check if 0<=value<1? if value < 1 and math.isclose(value, 0): setattr(self, attr, value * self._total_num) @staticmethod - def log_split(dict_like: dict, desc: str = None): + def log_split(dict_like: dict, desc: str = None) -> None: + """ + Logs the new created split + + Parameters + ---------- + dict_like : dict + the splits (usually this dict contains the keys 'train', 'val' + and (optionally) 'test' and a list of indices for each of them + desc : str, optional + the descriptor string to log before the actual splits + + """ if desc is not None: logger.info(desc) for key, item in dict_like.items(): @@ -172,9 +341,24 @@ def log_split(dict_like: dict, desc: str = None): @staticmethod def _copy_and_fill_dict(dict_like: dict, **kwargs) -> dict: + """ + copies the dict and adds the kwargs to the copy + + Parameters + ---------- + dict_like : dict + the dict to copy and fill + **kwargs : + the keyword argument added to the dict copy + + Returns + ------- + dict + the copied and filled dict + + """ new_dict = copy.deepcopy(dict_like) - for key, item in kwargs.items(): - new_dict[key] = item + new_dict.update(kwargs) 
return new_dict @property From 21af266d44a10c20811a3602929b143f63f97ce1 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Thu, 28 Nov 2019 15:38:35 +0100 Subject: [PATCH 25/39] Add docstrings for data container --- rising/loading/container.py | 92 +++++++++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/rising/loading/container.py b/rising/loading/container.py index 8d2157ac..ae1e22fb 100644 --- a/rising/loading/container.py +++ b/rising/loading/container.py @@ -12,16 +12,52 @@ # TODO: Add docstrings for Datacontainer class DataContainer: def __init__(self, dataset: Dataset, **kwargs): + """ + Handles the splitting of datasets from different sources + + Parameters + ---------- + dataset : dataset + the dataset to split + kwargs + """ self._dataset = dataset self._dset = {} self._fold = None + # TODO: this does not make sense, the constructor of object + # (which is the implicit base class here) does not take arguments. + # We should rather set them as attributes. super().__init__(**kwargs) - def split_by_index(self, split: SplitType): + def split_by_index(self, split: SplitType) -> None: + """ + Splits dataset by a given split-dict + + Parameters + ---------- + split : dict + a dictionary containing tuples of strings and lists of indices + for each split + + """ for key, idx in split.items(): self._dset[key] = self._dataset.get_subset(idx) def kfold_by_index(self, splits: typing.Iterable[SplitType]): + """ + Produces kfold splits based on the given indices. 
+ + Parameters + ---------- + splits : list + list containing split dicts for each fold + + Yields + ------ + DataContainer + the data container with updated dataset splits + + """ for fold, split in enumerate(splits): self.split_by_index(split) self._fold = fold @@ -29,7 +65,21 @@ def kfold_by_index(self, splits: typing.Iterable[SplitType]): self._fold = None def split_by_csv(self, path: typing.Union[pathlib.Path, str], - index_column: str, **kwargs): + index_column: str, **kwargs) -> None: + """ + Splits a dataset by splits given in a CSV file + + Parameters + ---------- + path : str, pathlib.Path + the path to the csv file + index_column : str + the label of the index column + **kwargs : + additional keyword arguments (see :func:`pandas.read_csv` for + details) + + """ df = pd.read_csv(path, **kwargs) df = df.set_index(index_column) col = list(df.columns) @@ -37,6 +87,25 @@ def split_by_csv(self, path: typing.Union[pathlib.Path, str], def kfold_by_csv(self, path: typing.Union[pathlib.Path, str], index_column: str, **kwargs) -> DataContainer: + """ + Produces kfold splits based on the given csv file. 
+ + Parameters + ---------- + path : str, pathlib.Path + the path to the csv file + index_column : str + the label of the index column + **kwargs : + additional keyword arguments (see :func:`pandas.read_csv` for + details) + + Yields + ------ + DataContainer + the data container with updated dataset splits + + """ df = pd.read_csv(path, **kwargs) df = df.set_index(index_column) folds = list(df.columns) @@ -44,7 +113,23 @@ def kfold_by_csv(self, path: typing.Union[pathlib.Path, str], yield from self.kfold_by_index((splits)) @staticmethod - def _read_split_from_df(df: pd.DataFrame, col: str): + def _read_split_from_df(df: pd.DataFrame, col: str) -> SplitType: + """ + Helper function to read a split from a given data frame + + Parameters + ---------- + df : pandas.DataFrame + the dataframe containing the split + col : str + the column inside the data frame containing the split + + Returns + ------- + dict + a dictionary of lists. Contains a string-list-tuple per split + + """ split = defaultdict(list) for index, row in df[[col]].iterrows(): split[str(row[col])].append(index) @@ -66,6 +151,7 @@ def fold(self) -> int: return self._fold +# TODO: Add Docstrings for datacontainerID class DataContainerID(DataContainer): def split_by_id(self, split: SplitType): split_idx = defaultdict(list) From 9325788b6f341db23098a8113aeb44a1f6cade00 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Thu, 28 Nov 2019 15:48:49 +0100 Subject: [PATCH 26/39] add docstinrgs for DataContainerID --- rising/loading/container.py | 95 +++++++++++++++++++++++++++++++++---- 1 file changed, 85 insertions(+), 10 deletions(-) diff --git a/rising/loading/container.py b/rising/loading/container.py index ae1e22fb..1f5bbac5 100644 --- a/rising/loading/container.py +++ b/rising/loading/container.py @@ -43,6 +43,8 @@ def split_by_index(self, split: SplitType) -> None: for key, idx in split.items(): self._dset[key] = self._dataset.get_subset(idx) + # TODO: Shouldn"t the kfold methods instead yield the current 
datasets + # instead of the whole cointainer? def kfold_by_index(self, splits: typing.Iterable[SplitType]): """ Produces kfold splits based on the given indices. @@ -151,18 +153,43 @@ def fold(self) -> int: return self._fold -# TODO: Add Docstrings for datacontainerID class DataContainerID(DataContainer): - def split_by_id(self, split: SplitType): + """ + Data Container Class for datasets with an ID + """ + def split_by_id(self, split: SplitType) -> None: + """ + Splits the internal dataset by the given splits + + Parameters + ---------- + split : dict + dictionary containing a string-list tuple per split + + """ split_idx = defaultdict(list) for key, _id in split.items(): for _i in _id: split_idx[key].append(self._dataset.get_index_by_id(_i)) - super().split_by_index(split_idx) + return super().split_by_index(split_idx) def kfold_by_id( self, - splits: typing.Iterable[SplitType]) -> DataContainerID: + splits: typing.Iterable[SplitType]): + """ + Produces kfold splits by an ID + + Parameters + ---------- + splits : list + list of dicts each containing the splits for a separate fold + + Yields + ------ + DataContaimnerID + the data container with updated internal datasets + + """ for fold, split in enumerate(splits): self.split_by_id(split) self._fold = fold @@ -170,14 +197,47 @@ def kfold_by_id( self._fold = None def split_by_csv_id(self, path: typing.Union[pathlib.Path, str], - id_column: str, **kwargs): + id_column: str, **kwargs) -> None: + """ + Splits the internal dataset by a given id column in a given csv file + + Parameters + ---------- + path : str or pathlib.Path + the path to the csv file + id_column : str + the key of the id_column + **kwargs : + additionalm keyword arguments (see :func:`pandas.read_csv` for + details) + + """ df = pd.read_csv(path, **kwargs) df = df.set_index(id_column) col = list(df.columns) - self.split_by_id(self._read_split_from_df(df, col[0])) + return self.split_by_id(self._read_split_from_df(df, col[0])) def 
kfold_by_csv_id(self, path: typing.Union[pathlib.Path, str], - id_column: str, **kwargs) -> DataContainerID: + id_column: str, **kwargs): + """ + Produces kfold splits by an ID column of a given csv file + + Parameters + ---------- + path : str or pathlib.Path + the path to the csv file + id_column : str + the key of the id_column + **kwargs : + additionalm keyword arguments (see :func:`pandas.read_csv` for + details) + + Yields + ------ + DataContaimnerID + the data container with updated internal datasets + + """ df = pd.read_csv(path, **kwargs) df = df.set_index(id_column) folds = list(df.columns) @@ -185,11 +245,26 @@ def kfold_by_csv_id(self, path: typing.Union[pathlib.Path, str], yield from self.kfold_by_id((splits)) def save_split_to_csv_id(self, - path: typing.Union[pathlib.Path, - str], + path: typing.Union[pathlib.Path, str], id_key: str, split_column: str = 'split', - **kwargs): + **kwargs) -> None: + """ + Saves a split top a given csv id + + Parameters + ---------- + path : str or pathlib.Path + the path of the csv file + id_key : str + the id key inside the csv file + split_column : str + the name of the split_column inside the csv file + **kwargs : + additional keyword arguments (see :meth:`pd.DataFrame.to_csv` + for details) + + """ split_dict = {str(id_key): [], str(split_column): []} for key, item in self._dset.items(): for sample in item: From 1bef306a3879ae8ceaffec65ffa9ee92416dda74 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Thu, 28 Nov 2019 15:51:53 +0100 Subject: [PATCH 27/39] Add comment --- rising/loading/loader.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rising/loading/loader.py b/rising/loading/loader.py index 0a289cb1..84d35e82 100644 --- a/rising/loading/loader.py +++ b/rising/loading/loader.py @@ -187,6 +187,9 @@ def __init__(self, loader): try: import numpy as np + # generate numpy seed. The range comes so that the seed in each + # worker (which is this baseseed plus the worker id) is always an + # uint32. 
This is because numpy only accepts uint32 as valid seeds npy_seed = np.random.randint(0, (2 ** 32) - (1 + loader.num_workers)) except ImportError: From 85946d9da162a3a18483fcacc3bc69f7bee217b3 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Thu, 28 Nov 2019 15:52:44 +0100 Subject: [PATCH 28/39] move future imports to top of file --- rising/loading/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rising/loading/loader.py b/rising/loading/loader.py index 84d35e82..27fe3350 100644 --- a/rising/loading/loader.py +++ b/rising/loading/loader.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Callable, Mapping, Sequence, Union, Any from torch.utils.data._utils.collate import default_convert from torch.utils.data import DataLoader as _DataLoader, Sampler @@ -8,7 +9,6 @@ from functools import partial from rising.loading.dataset import Dataset from threadpoolctl import threadpool_limits -from __future__ import annotations class DataLoader(_DataLoader): From 1008e2722659eb10a5d9a84cb823dc37b6ca5ce2 Mon Sep 17 00:00:00 2001 From: Michael Baumgartner Date: Thu, 28 Nov 2019 15:59:06 +0000 Subject: [PATCH 29/39] autopep8 fix --- rising/loading/container.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rising/loading/container.py b/rising/loading/container.py index 1f5bbac5..d73f388d 100644 --- a/rising/loading/container.py +++ b/rising/loading/container.py @@ -157,6 +157,7 @@ class DataContainerID(DataContainer): """ Data Container Class for datasets with an ID """ + def split_by_id(self, split: SplitType) -> None: """ Splits the internal dataset by the given splits @@ -270,4 +271,4 @@ def save_split_to_csv_id(self, for sample in item: split_dict[str(id_key)].append(sample[id_key]) split_dict[str(split_column)].append(str(key)) - pd.DataFrame(split_dict).to_csv(path, **kwargs) \ No newline at end of file + pd.DataFrame(split_dict).to_csv(path, **kwargs) From 5edf840054d0eddc32af8d0fd41ec8829fa7d911 Mon Sep 
17 00:00:00 2001 From: mibaumgartner Date: Sun, 1 Dec 2019 19:47:20 +0100 Subject: [PATCH 30/39] fix bugs and tests --- .codecov.yml | 1 + rising/loading/dataset.py | 60 +++++++++++++++++++-------------- tests/loading/_src/kfold.csv | 7 ++++ tests/loading/_src/split.csv | 7 ++++ tests/loading/test_collate.py | 22 ++++++++---- tests/loading/test_container.py | 6 ++-- tests/loading/test_dataset.py | 6 ++++ 7 files changed, 75 insertions(+), 34 deletions(-) create mode 100644 tests/loading/_src/kfold.csv create mode 100644 tests/loading/_src/split.csv diff --git a/.codecov.yml b/.codecov.yml index e5b9d64f..c6370468 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -11,4 +11,5 @@ coverage: ignore: - "tests/" - "*/__init.py" + - "_version.py" diff --git a/rising/loading/dataset.py b/rising/loading/dataset.py index 5cb71a03..6f831f20 100644 --- a/rising/loading/dataset.py +++ b/rising/loading/dataset.py @@ -32,7 +32,6 @@ def get_subset(self, indices: typing.Sequence[int]) -> SubsetDataset: :class:`SubsetDataset` the subset """ - # extract other important attributes from current dataset kwargs = {} @@ -43,14 +42,10 @@ def get_subset(self, indices: typing.Sequence[int]) -> SubsetDataset: continue kwargs[key] = val - kwargs["old_getitem"] = self.__class__.__getitem__ + old_getitem = self.__class__.__getitem__ subset_data = [self[idx] for idx in indices] - return SubsetDataset(subset_data, **kwargs) - - -# NOTE: For backward compatibility (should be removed ASAP) -AbstractDataset = Dataset + return SubsetDataset(subset_data, old_getitem, **kwargs) class SubsetDataset(Dataset): @@ -71,7 +66,7 @@ def __init__(self, data: typing.Sequence, old_getitem: typing.Callable, **kwargs : additional keyword arguments (are set as class attribute) """ - super().__init__(None, None) + super().__init__() self.data = data self._old_getitem = old_getitem @@ -109,9 +104,7 @@ def __len__(self) -> int: class CacheDataset(Dataset): def __init__(self, - data_path: 
typing.Union[typing.Union[pathlib.Path, - str], - list], + data_path: typing.Union[typing.Union[pathlib.Path, str], list], load_fn: typing.Callable, mode: str = "append", num_workers: int = None, @@ -163,8 +156,7 @@ def __init__(self, self._load_kwargs = load_kwargs self.data = self._make_dataset(data_path, mode) - def _make_dataset(self, path: typing.Union[typing.Union[pathlib.Path, str], - list], + def _make_dataset(self, path: typing.Union[typing.Union[pathlib.Path, str], list], mode: str) -> typing.List[dict]: """ Function to build the entire dataset @@ -201,7 +193,7 @@ def _make_dataset(self, path: typing.Union[typing.Union[pathlib.Path, str], else: if self._verbosity: path = tqdm(path, unit='samples', desc="Loading Samples") - _data = map(load_fn, path) + _data = map(load_fn, path) for sample in _data: self._add_item(data, sample, mode) @@ -250,6 +242,17 @@ def __getitem__(self, index: int) -> typing.Union[typing.Any, typing.Dict]: """ return self.data[index] + def __len__(self) -> int: + """ + Length of dataset + + Returns + ------- + int + number of elements + """ + return len(self.data) + class LazyDataset(Dataset): def __init__(self, data_path: typing.Union[str, list], @@ -315,11 +318,21 @@ def __getitem__(self, index: int) -> dict: **self._load_kwargs) return data_dict + def __len__(self) -> int: + """ + Length of dataset + + Returns + ------- + int + number of elements + """ + return len(self.data) + # TODO: Maybe we should add the dataset baseclass as baseclass of this as well # (since it should just extend it and still have all the other dataset # functionalities)? 
- class IDManager: def __init__(self, id_key: str, cache_ids: bool = True, **kwargs): """ @@ -334,6 +347,7 @@ def __init__(self, id_key: str, cache_ids: bool = True, **kwargs): **kwargs : additional keyword arguments """ + super().__init__(**kwargs) self.id_key = id_key self._cached_ids = None @@ -411,7 +425,7 @@ def get_index_by_id(self, id: str) -> int: return self._find_index_iterative(id) -class CacheDatasetID(CacheDataset, IDManager): +class CacheDatasetID(IDManager, CacheDataset): def __init__(self, data_path, load_fn, id_key, cache_ids=True, **kwargs): """ @@ -432,13 +446,11 @@ def __init__(self, data_path, load_fn, id_key, cache_ids=True, """ # TODO: Shouldn't we call the baseclasses explicitly here? with super # it is not clear, which baseclass is actually called - super().__init__(data_path, load_fn, **kwargs) - # check if AbstractDataset did not call IDManager with super - if not hasattr(self, "id_key"): - IDManager.__init__(self, id_key, cache_ids=cache_ids) + super().__init__(data_path=data_path, load_fn=load_fn, id_key=id_key, + cache_ids=cache_ids, **kwargs) -class LazyDatasetID(LazyDataset, IDManager): +class LazyDatasetID(IDManager, LazyDataset): def __init__(self, data_path, load_fn, id_key, cache_ids=True, **kwargs): """ @@ -459,7 +471,5 @@ def __init__(self, data_path, load_fn, id_key, cache_ids=True, """ # TODO: Shouldn't we call the baseclasses explicitly here? 
with super # it is not clear, which baseclass is actually called - super().__init__(data_path, load_fn, **kwargs) - # check if AbstractDataset did not call IDManager with super - if not hasattr(self, "id_key"): - IDManager.__init__(self, id_key, cache_ids=cache_ids) + super().__init__(data_path=data_path, load_fn=load_fn, id_key=id_key, + cache_ids=cache_ids, **kwargs) diff --git a/tests/loading/_src/kfold.csv b/tests/loading/_src/kfold.csv new file mode 100644 index 00000000..8a6f1ae7 --- /dev/null +++ b/tests/loading/_src/kfold.csv @@ -0,0 +1,7 @@ +index;fold0;fold1;fold2 +0;train;test;val +1;train;test;val +2;val;train;test +3;val;train;test +4;test;val;train +5;test;val;train \ No newline at end of file diff --git a/tests/loading/_src/split.csv b/tests/loading/_src/split.csv new file mode 100644 index 00000000..b1f886f6 --- /dev/null +++ b/tests/loading/_src/split.csv @@ -0,0 +1,7 @@ +index;fold0 +0;train +1;train +2;val +3;val +4;test +5;test \ No newline at end of file diff --git a/tests/loading/test_collate.py b/tests/loading/test_collate.py index cdf47ee5..86e3cb22 100644 --- a/tests/loading/test_collate.py +++ b/tests/loading/test_collate.py @@ -12,25 +12,33 @@ # TODO: Add more collate test cases class TestCollate(unittest.TestCase): @unittest.skipIf(np is None, 'numpy is not available') - def test_default_collate_dtype(self): + def test_default_collate_int(self): arr = [1, 2, -1] collated = numpy_collate(arr) - self.assertEqual(collated, np.array(arr)) - self.assertEqual(collated.dtype, np.int32) + expected = np.array(arr) + self.assertTrue((collated == expected).all()) + self.assertEqual(collated.dtype, expected.dtype) + @unittest.skipIf(np is None, 'numpy is not available') + def test_default_collate_float(self): arr = [1.1, 2.3, -0.9] collated = numpy_collate(arr) - self.assertEqual(collated, np.array(arr)) - self.assertEqual(collated.dtype, np.float32) + expected = np.array(arr) + self.assertTrue((collated == expected).all()) + 
self.assertEqual(collated.dtype, expected.dtype) + @unittest.skipIf(np is None, 'numpy is not available') + def test_default_collate_bool(self): arr = [True, False] collated = numpy_collate(arr) - self.assertEqual(collated, np.array(arr)) + self.assertTrue(all(collated == np.array(arr))) self.assertEqual(collated.dtype, np.bool) + @unittest.skipIf(np is None, 'numpy is not available') + def test_default_collate_str(self): # Should be a no-op arr = ['a', 'b', 'c'] - self.assertEqual(arr, numpy_collate(arr)) + self.assertTrue((arr == numpy_collate(arr))) if __name__ == '__main__': diff --git a/tests/loading/test_container.py b/tests/loading/test_container.py index 9259918b..449a4bcc 100644 --- a/tests/loading/test_container.py +++ b/tests/loading/test_container.py @@ -24,9 +24,11 @@ def __call__(self, path, *args, **kwargs): class TestDataContainer(unittest.TestCase): def setUp(self): self.dset = DummyDataset(num_samples=6, - load_fn=(LoadDummySampleID())) + load_fn=(LoadDummySampleID()), + ) self.dset_id = DummyDatasetID(num_samples=6, - load_fn=LoadDummySampleID()) + load_fn=LoadDummySampleID(), + ) self.split = {"train": [0, 1, 2], "val": [3, 4, 5]} self.kfold = [{"train": [0, 1, 2], "val": [3, 4, 5]}, {"val": [0, 1, 2], "train": [3, 4, 5]}] diff --git a/tests/loading/test_dataset.py b/tests/loading/test_dataset.py index 76363d6e..0d2f3648 100644 --- a/tests/loading/test_dataset.py +++ b/tests/loading/test_dataset.py @@ -32,12 +32,18 @@ def load_mul_sample(path) -> list: return [LoadDummySample()(path, None)] * 4 dataset = CacheDataset(self.paths, load_mul_sample, + num_workers=0, verbose=True, mode='extend') self.assertEqual(len(dataset), 40) self.check_dataset_access(dataset, [0, 20, 39]) self.check_dataset_outside_access(dataset, [40, 45]) self.check_dataset_iter(dataset) + def test_cache_dataset_mode_error(self): + with self.assertRaises(TypeError): + dataset = CacheDataset(self.paths, LoadDummySample(), + label_load_fct=None, mode="no_mode:P") + def 
test_lazy_dataset(self): dataset = LazyDataset(self.paths, LoadDummySample(), label_load_fct=None) From b949d3797a423681578710a6f8ec3f5c93a72e6b Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 1 Dec 2019 20:28:27 +0100 Subject: [PATCH 31/39] dataset tests --- rising/loading/__init__.py | 1 + rising/loading/dataset.py | 10 +++--- rising/loading/debug_mode.py | 2 +- rising/loading/loader.py | 4 +-- rising/loading/splitter.py | 27 ++++++--------- tests/loading/test_dataset.py | 63 +++++++++++++++++++++++++++++++++++ 6 files changed, 82 insertions(+), 25 deletions(-) diff --git a/rising/loading/__init__.py b/rising/loading/__init__.py index 44a3deaa..135e4153 100644 --- a/rising/loading/__init__.py +++ b/rising/loading/__init__.py @@ -1,3 +1,4 @@ from rising.loading.collate import numpy_collate from rising.loading.dataset import Dataset from rising.loading.loader import DataLoader +from rising.loading.debug_mode import get_debug_mode, set_debug_mode, switch_debug_mode diff --git a/rising/loading/dataset.py b/rising/loading/dataset.py index 6f831f20..0a0eeae5 100644 --- a/rising/loading/dataset.py +++ b/rising/loading/dataset.py @@ -8,7 +8,7 @@ import warnings from torch.utils.data import Dataset as TorchDset -from rising.loading.debug_mode import get_current_debug_mode +from rising.loading.debug_mode import get_debug_mode from torch.multiprocessing import Pool @@ -137,16 +137,14 @@ def __init__(self, """ super().__init__() - if (get_current_debug_mode() and - (num_workers is None or num_workers > 0)): + if get_debug_mode() and (num_workers is None or num_workers > 0): warnings.warn("The debug mode has been activated. " - "Falling back to num_workers = 0") + "Falling back to num_workers = 0", UserWarning) num_workers = 0 if (num_workers is None or num_workers > 0) and verbose: warnings.warn("Verbosity is mutually exclusive with " - "num_workers > 0. Setting it to False instead.", - UserWarning) + "num_workers > 0. 
Setting it to False instead.", UserWarning) verbose = False self._num_workers = num_workers diff --git a/rising/loading/debug_mode.py b/rising/loading/debug_mode.py index d0028c07..e94c8d00 100644 --- a/rising/loading/debug_mode.py +++ b/rising/loading/debug_mode.py @@ -10,7 +10,7 @@ # (even if this slows down things a lot!). -def get_current_debug_mode(): +def get_debug_mode(): """ Getter function for the current debug mode Returns diff --git a/rising/loading/loader.py b/rising/loading/loader.py index 27fe3350..886abdd4 100644 --- a/rising/loading/loader.py +++ b/rising/loading/loader.py @@ -5,7 +5,7 @@ from torch.utils.data.dataloader import \ _SingleProcessDataLoaderIter, _MultiProcessingDataLoaderIter as \ __MultiProcessingDataLoaderIter -from rising.loading.debug_mode import get_current_debug_mode +from rising.loading.debug_mode import get_debug_mode from functools import partial from rising.loading.dataset import Dataset from threadpoolctl import threadpool_limits @@ -115,7 +115,7 @@ def __init__(self, dataset: Union[Sequence, Dataset], def __iter__(self) -> Union[_SingleProcessDataLoaderIter, _MultiProcessingDataLoaderIter]: - if self.num_workers == 0 or get_current_debug_mode(): + if self.num_workers == 0 or get_debug_mode(): return _SingleProcessDataLoaderIter(self) else: return _MultiProcessingDataLoaderIter(self) diff --git a/rising/loading/splitter.py b/rising/loading/splitter.py index efe02a7f..b38fd35b 100644 --- a/rising/loading/splitter.py +++ b/rising/loading/splitter.py @@ -1,6 +1,7 @@ import copy import typing import logging +import warnings import math from sklearn.model_selection import train_test_split, GroupShuffleSplit, \ KFold, GroupKFold, StratifiedKFold @@ -12,13 +13,11 @@ SplitType = typing.Dict[str, list] -# TODO: I would probably change this and make val_size optionally. 
-# We always need a testset, but not always a validationset class Splitter: def __init__(self, dataset: Dataset, - val_size: typing.Union[int, float], - test_size: typing.Union[int, float] = None): + val_size: typing.Union[int, float] = 0, + test_size: typing.Union[int, float] = 0): """ Splits a dataset by several options @@ -40,15 +39,17 @@ def __init__(self, if not provided or explicitly set to None, no testset will be created """ - # TODO: Since we only have object as implicit baseclass the - # super().__init__() can probably be removed super().__init__() + if val_size == 0 and test_size == 0: + warnings.warn("Can not perform splitting if val and test size is 0.") + self._dataset = dataset self._total_num = len(self._dataset) self._idx = list(range(self._total_num)) self._val = val_size - self._test = test_size if test_size is not None else 0 + self._test = test_size + self._convert_prop_to_num() self._check_sizes() def _check_sizes(self): @@ -71,12 +72,6 @@ def _check_sizes(self): raise ValueError("Size must be larger than zero, not " "{}".format(self._test)) - # TODO: Can we explicitly call this check in the __init__ before - # checking the sizes? Explicit is better than implicit - # When I first checked the Code I was wondering where this was done - # and I could not find it, since this function should only check - # and not convert anything - self._convert_prop_to_num() if self._total_num < self._val + self._test: raise ValueError("Val + test size must be smaller than total, " "not {}".format(self._val + self._test)) @@ -315,9 +310,7 @@ def _convert_prop_to_num(self, attributes: tuple = ("_val", "_test") """ for attr in attributes: value = getattr(self, attr) - # TODO: When is a value close to 0? - # Shouldn't we only check if 0<=value<1? 
- if value < 1 and math.isclose(value, 0): + if 0 < value < 1: setattr(self, attr, value * self._total_num) @staticmethod @@ -378,6 +371,7 @@ def val_size(self) -> int: @val_size.setter def val_size(self, value: typing.Union[int, float]): self._val = value + self._convert_prop_to_num() self._check_sizes() @property @@ -387,6 +381,7 @@ def test_size(self) -> int: @test_size.setter def test_size(self, value: typing.Union[int, float]): self._test = value + self._convert_prop_to_num() self._check_sizes() @property diff --git a/tests/loading/test_dataset.py b/tests/loading/test_dataset.py index 0d2f3648..2e3426ca 100644 --- a/tests/loading/test_dataset.py +++ b/tests/loading/test_dataset.py @@ -1,8 +1,13 @@ import unittest +import os +import tempfile +import shutil +import pickle import numpy as np from rising.loading.dataset import CacheDataset, LazyDataset, CacheDatasetID, \ LazyDatasetID +from rising.loading import get_debug_mode, set_debug_mode # TODO: Additional Tests for subsetdataset @@ -15,6 +20,39 @@ def __call__(self, path, *args, **kwargs): return data +def pickle_save(path, data): + with open(path, "wb") as f: + pickle.dump(data, f) + + +def pickle_load(path, *args, **kwargs): + with open(path, "rb") as f: + pickle.load(f) + + +class TestBaseDatasetDir(unittest.TestCase): + def setUp(self) -> None: + self.dir = tempfile.mkdtemp(dir=os.path.dirname(os.path.realpath(__file__))) + loader = LoadDummySample() + for i in range(10): + pickle_save(os.path.join(self.dir, f"sample{i}.pkl"), loader(i)) + + def tearDown(self) -> None: + shutil.rmtree(self.dir) + + def test_cache_dataset_dir(self): + dataset = CacheDataset(self.dir, pickle_load, label_load_fct=None) + self.assertEqual(len(dataset), 10) + for i in dataset: + pass + + def test_lazy_dataset_dir(self): + dataset = LazyDataset(self.dir, pickle_load, label_load_fct=None) + self.assertEqual(len(dataset), 10) + for i in dataset: + pass + + class TestBaseDataset(unittest.TestCase): def setUp(self): self.paths = 
list(range(10)) @@ -27,6 +65,21 @@ def test_cache_dataset(self): self.check_dataset_outside_access(dataset, [10, 20]) self.check_dataset_iter(dataset) + def test_cache_num_worker_warn(self): + set_debug_mode(True) + with self.assertWarns(UserWarning): + dataset = CacheDataset(self.paths, LoadDummySample(), + num_workers=4, + label_load_fct=None) + set_debug_mode(False) + + def test_cache_verbose_warn(self): + + with self.assertWarns(UserWarning): + dataset = CacheDataset(self.paths, LoadDummySample(), + num_workers=4, verbose=True, + label_load_fct=None) + def test_cache_dataset_extend(self): def load_mul_sample(path) -> list: return [LoadDummySample()(path, None)] * 4 @@ -75,6 +128,16 @@ def check_dataset_iter(self, dataset): except BaseException: raise AssertionError('Dataset iteration failed.') + def test_subset_dataset(self): + idx = [0, 1, 2, 5, 6] + dataset = CacheDataset(self.paths, LoadDummySample(), + label_load_fct=None) + subset = dataset.get_subset(idx) + self.assertEqual(len(subset), len(idx)) + for _i, _idx in enumerate(idx): + self.assertEqual(subset[_i]["id"], dataset[_idx]["id"]) + with self.assertRaises(IndexError): + subset[len(idx)] class TestDatasetID(unittest.TestCase): def test_load_dummy_sample(self): From 8f7ffa40c7e88020984564a82d35afc9a54d7749 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 1 Dec 2019 21:12:17 +0100 Subject: [PATCH 32/39] collate and debug mode tests --- rising/loading/debug_mode.py | 2 +- tests/loading/test_collate.py | 55 +++++++++++++++++++++++++++++--- tests/loading/test_debug_mode.py | 22 ++++++++++++- 3 files changed, 73 insertions(+), 6 deletions(-) diff --git a/rising/loading/debug_mode.py b/rising/loading/debug_mode.py index e94c8d00..a75485e1 100644 --- a/rising/loading/debug_mode.py +++ b/rising/loading/debug_mode.py @@ -25,7 +25,7 @@ def switch_debug_mode(): """ Alternates the current debug mode """ - set_debug_mode(not get_current_debug_mode()) + set_debug_mode(not get_debug_mode()) def 
set_debug_mode(mode: bool): diff --git a/tests/loading/test_collate.py b/tests/loading/test_collate.py index 86e3cb22..275fdb25 100644 --- a/tests/loading/test_collate.py +++ b/tests/loading/test_collate.py @@ -1,5 +1,6 @@ import torch import unittest +from collections import namedtuple try: import numpy as np @@ -12,7 +13,7 @@ # TODO: Add more collate test cases class TestCollate(unittest.TestCase): @unittest.skipIf(np is None, 'numpy is not available') - def test_default_collate_int(self): + def test_numpy_collate_int(self): arr = [1, 2, -1] collated = numpy_collate(arr) expected = np.array(arr) @@ -20,7 +21,7 @@ def test_default_collate_int(self): self.assertEqual(collated.dtype, expected.dtype) @unittest.skipIf(np is None, 'numpy is not available') - def test_default_collate_float(self): + def test_numpy_collate_float(self): arr = [1.1, 2.3, -0.9] collated = numpy_collate(arr) expected = np.array(arr) @@ -28,18 +29,64 @@ def test_default_collate_float(self): self.assertEqual(collated.dtype, expected.dtype) @unittest.skipIf(np is None, 'numpy is not available') - def test_default_collate_bool(self): + def test_numpy_collate_bool(self): arr = [True, False] collated = numpy_collate(arr) self.assertTrue(all(collated == np.array(arr))) self.assertEqual(collated.dtype, np.bool) @unittest.skipIf(np is None, 'numpy is not available') - def test_default_collate_str(self): + def test_numpy_collate_str(self): # Should be a no-op arr = ['a', 'b', 'c'] self.assertTrue((arr == numpy_collate(arr))) + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_ndarray(self): + arr = [np.array(0), np.array(1), np.array(2)] + collated = numpy_collate(arr) + expected = np.array([0, 1, 2]) + self.assertTrue((collated == expected).all()) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_tensor(self): + arr = [torch.tensor(0), torch.tensor(1), torch.tensor(2)] + collated = numpy_collate(arr) + expected = np.array([0, 1, 2]) + 
self.assertTrue((collated == expected).all()) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_mapping(self): + arr = [{"a": np.array(0), "b": np.array(1)}] * 2 + collated = numpy_collate(arr) + expected = {"a": np.array([0, 0]), "b": np.array([1, 1])} + for key in expected.keys(): + self.assertTrue((collated[key] == expected[key]).all()) + self.assertEqual(len(expected.keys()), len(collated.keys())) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_sequence(self): + arr = [[np.array(0), np.array(1)], [np.array(0), np.array(1)]] + collated = numpy_collate(arr) + expected = [np.array([0, 0]), np.array([1, 1])] + for i in range(len(collated)): + self.assertTrue((collated[i] == expected[i]).all()) + self.assertEqual(len(expected), len(collated)) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_error(self): + with self.assertRaises(TypeError): + collated = numpy_collate([{"a", "b"}, {"a", "b"}]) + + @unittest.skipIf(np is None, 'numpy is not available') + def test_numpy_collate_named_tuple(self): + Point = namedtuple('Point', ['x', 'y']) + arr = [Point(0, 1), Point(2, 3)] + collated = numpy_collate(arr) + expected = Point(np.array([0, 2]), np.array([1, 3])) + self.assertTrue((collated.x == expected.x).all()) + self.assertTrue((collated.y == expected.y).all()) + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tests/loading/test_debug_mode.py b/tests/loading/test_debug_mode.py index 8051fd04..e573d73d 100644 --- a/tests/loading/test_debug_mode.py +++ b/tests/loading/test_debug_mode.py @@ -1 +1,21 @@ -# TODO: Add tests for debug mode \ No newline at end of file +import unittest + +from rising.loading import get_debug_mode, set_debug_mode, switch_debug_mode + + +class TestDebugMode(unittest.TestCase): + def test_debug_mode(self): + set_debug_mode(False) + self.assertFalse(get_debug_mode()) + set_debug_mode(True) + 
self.assertTrue(get_debug_mode()) + set_debug_mode(False) + self.assertFalse(get_debug_mode()) + + def test_switch_debug_mode(self): + set_debug_mode(False) + self.assertFalse(get_debug_mode()) + switch_debug_mode() + self.assertTrue(get_debug_mode()) + switch_debug_mode() + self.assertFalse(get_debug_mode()) From 91bef3011bdfeebca6cddd8eb9be4a3283727b65 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 1 Dec 2019 21:13:56 +0100 Subject: [PATCH 33/39] tqdm to requirements --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 771408bb..5b2778d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ numpy torch threadpoolctl pandas -sklearn \ No newline at end of file +sklearn +tqdm \ No newline at end of file From 1de2e8e20777aacdc1e4c04c53739c2cc0375fe1 Mon Sep 17 00:00:00 2001 From: Michael Baumgartner Date: Sun, 1 Dec 2019 20:15:00 +0000 Subject: [PATCH 34/39] autopep8 fix --- tests/loading/test_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/loading/test_dataset.py b/tests/loading/test_dataset.py index 2e3426ca..8cd1a88b 100644 --- a/tests/loading/test_dataset.py +++ b/tests/loading/test_dataset.py @@ -139,6 +139,7 @@ def test_subset_dataset(self): with self.assertRaises(IndexError): subset[len(idx)] + class TestDatasetID(unittest.TestCase): def test_load_dummy_sample(self): load_fn = LoadDummySample() @@ -200,4 +201,4 @@ def check_dset_id(self, dset): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From a4024c8b98d5163753da3600e539a24adf0309b3 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Mon, 2 Dec 2019 21:04:02 +0100 Subject: [PATCH 35/39] remove print --- tests/transforms/test_kernel_transforms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/transforms/test_kernel_transforms.py b/tests/transforms/test_kernel_transforms.py index 623730d8..0b1a4dc8 100644 --- 
a/tests/transforms/test_kernel_transforms.py +++ b/tests/transforms/test_kernel_transforms.py @@ -41,7 +41,6 @@ def test_gaussian_smoothing_transform(self): dim=2, stride=1, padding=1) self.batch_dict["data"][0, 0, 1] = 1 outp = trafo(**self.batch_dict) - print(outp["data"].shape) if __name__ == '__main__': From 344affcb774e68a55a7b0c1ce1642ad264ee6abe Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Tue, 3 Dec 2019 18:54:28 +0100 Subject: [PATCH 36/39] loader tests --- tests/loading/test_loader.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/loading/test_loader.py b/tests/loading/test_loader.py index a71ee5bc..58e5dfe5 100644 --- a/tests/loading/test_loader.py +++ b/tests/loading/test_loader.py @@ -1 +1,10 @@ -# TODO: Add Loader Tests \ No newline at end of file +import unittest + + +class MyTestCase(unittest.TestCase): + def test_something(self): + self.assertEqual(True, False) + + +if __name__ == '__main__': + unittest.main() From 5d352e514b9d06e4a515f6e99c687c9f264d7fea Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Tue, 3 Dec 2019 18:54:38 +0100 Subject: [PATCH 37/39] loader tests --- rising/loading/loader.py | 15 ++++--- tests/loading/test_loader.py | 83 ++++++++++++++++++++++++++++++++++-- 2 files changed, 88 insertions(+), 10 deletions(-) diff --git a/rising/loading/loader.py b/rising/loading/loader.py index 886abdd4..b67a6507 100644 --- a/rising/loading/loader.py +++ b/rising/loading/loader.py @@ -9,10 +9,10 @@ from functools import partial from rising.loading.dataset import Dataset from threadpoolctl import threadpool_limits +import numpy as np class DataLoader(_DataLoader): - def __init__(self, dataset: Union[Sequence, Dataset], batch_size: int = 1, shuffle: bool = False, batch_transforms: Callable = None, sampler: Sampler = None, @@ -153,7 +153,6 @@ def __call__(self, *args, **kwargs) -> Any: batch = self._collate_fn(*args, **kwargs) if self._transforms is not None: - if isinstance(batch, Mapping): batch = 
self._transforms(**batch) elif isinstance(batch, Sequence): @@ -190,8 +189,7 @@ def __init__(self, loader): # generate numpy seed. The range comes so that the seed in each # worker (which is this baseseed plus the worker id) is always an # uint32. This is because numpy only accepts uint32 as valid seeds - npy_seed = np.random.randint(0, - (2 ** 32) - (1 + loader.num_workers)) + npy_seed = np.random.randint(0, (2 ** 32) - (1 + loader.num_workers)) except ImportError: # we don't generate a numpy seed here with torch, since we don't # need one; if the import fails in the main process it should @@ -200,9 +198,12 @@ def __init__(self, loader): old_worker_init = loader.worker_init_fn - new_worker_init_fn = partial(_seed_npy_before_worker_init, - seed=npy_seed, - worker_init_fn=old_worker_init) + if npy_seed is None: + new_worker_init_fn = old_worker_init + else: + new_worker_init_fn = partial(_seed_npy_before_worker_init, + seed=npy_seed, + worker_init_fn=old_worker_init) loader.worker_init_fn = new_worker_init_fn with threadpool_limits(limits=1, user_api='blas'): diff --git a/tests/loading/test_loader.py b/tests/loading/test_loader.py index 58e5dfe5..b8c7b8de 100644 --- a/tests/loading/test_loader.py +++ b/tests/loading/test_loader.py @@ -1,9 +1,86 @@ import unittest +import torch +from typing import Sequence, Mapping +import numpy as np +from unittest.mock import Mock, patch +from torch.utils.data.dataloader import _SingleProcessDataLoaderIter +from rising.loading.loader import _seed_npy_before_worker_init, DataLoader, \ + BatchTransformer, _MultiProcessingDataLoaderIter -class MyTestCase(unittest.TestCase): - def test_something(self): - self.assertEqual(True, False) + +class TestLoader(unittest.TestCase): + def test_seed_npy_before_worker_init(self): + expected_return = 100 + np.random.seed(1) + expected = np.random.rand(1) + worker_init = Mock(return_value=expected_return) + + output_return = _seed_npy_before_worker_init(worker_id=1, seed=0, + 
worker_init_fn=worker_init) + output = np.random.rand(1) + self.assertEqual(output, expected) + self.assertEqual(output_return, expected_return) + worker_init.assert_called_once_with(1) + + def test_seed_npy_before_worker_init_import_error(self): + with patch.dict('sys.modules', {'numpy': None}): + expected_return = 100 + worker_init = Mock(return_value=expected_return) + output_return = _seed_npy_before_worker_init(worker_id=1, seed=0, + worker_init_fn=worker_init) + self.assertEqual(output_return, expected_return) + worker_init.assert_called_once_with(1) + + def check_batch_transformer(self, collate_output): + collate = Mock(return_value=collate_output) + transforms = Mock(return_value=2) + transformer = BatchTransformer(collate_fn=collate, transforms=transforms, + auto_convert=False) + + output = transformer(0) + + collate.assert_called_once_with(0) + if isinstance(collate_output, Sequence): + transforms.assert_called_once_with(*collate_output) + elif isinstance(collate_output, Mapping): + transforms.assert_called_once_with(**collate_output) + else: + transforms.assert_called_once_with(collate_output) + self.assertEqual(2, output) + + def test_batch_transformer(self): + self.check_batch_transformer(0) + + def test_batch_transformer_sequence(self): + self.check_batch_transformer((0, 1)) + + def test_batch_transformer_mapping(self): + self.check_batch_transformer({"a": 0}) + + def test_batch_transformer_auto_convert(self): + collate = Mock(return_value=0) + transforms = Mock(return_value=np.array([0, 1])) + transformer = BatchTransformer(collate_fn=collate, transforms=transforms, + auto_convert=True) + output = transformer(0) + self.assertTrue((output == torch.tensor([0, 1])).all()) + + def test_dataloader_np_import_error(self): + with patch.dict('sys.modules', {'numpy': None}): + loader = DataLoader([0, 1, 2], num_workers=2) + iterator = iter(loader) + self.assertIsInstance(iterator, _MultiProcessingDataLoaderIter) + + def test_dataloader_single_process(self): + 
loader = DataLoader([0, 1, 2]) + iterator = iter(loader) + self.assertIsInstance(iterator, _SingleProcessDataLoaderIter) + + def test_dataloader_multi_process(self): + loader = DataLoader([0, 1, 2], num_workers=2) + iterator = iter(loader) + self.assertIsInstance(iterator, _MultiProcessingDataLoaderIter) if __name__ == '__main__': From 339e639e0eb0f9aa49b4e6d46ff92421ed5a57ec Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Tue, 3 Dec 2019 20:28:56 +0100 Subject: [PATCH 38/39] add splitter to coverage ignore, convert todos to github --- rising/loading/container.py | 8 ++------ rising/loading/splitter.py | 18 ++++++------------ rising/utils/checktype.py | 13 +++++++++++++ tests/loading/test_dataset.py | 2 -- 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/rising/loading/container.py b/rising/loading/container.py index d73f388d..a1a49a6d 100644 --- a/rising/loading/container.py +++ b/rising/loading/container.py @@ -9,9 +9,8 @@ from rising.loading.splitter import SplitType -# TODO: Add docstrings for Datacontainer class DataContainer: - def __init__(self, dataset: Dataset, **kwargs): + def __init__(self, dataset: Dataset): """ Handles the splitting of datasets from different sources @@ -24,10 +23,7 @@ def __init__(self, dataset: Dataset, **kwargs): self._dataset = dataset self._dset = {} self._fold = None - # TODO: this does not make sense, the constructor of object - # (which is the implicit base class here) does not take arguments. - # We should rather set them as attributes. 
- super().__init__(**kwargs) + super().__init__() def split_by_index(self, split: SplitType) -> None: """ diff --git a/rising/loading/splitter.py b/rising/loading/splitter.py index b38fd35b..304391ac 100644 --- a/rising/loading/splitter.py +++ b/rising/loading/splitter.py @@ -2,7 +2,6 @@ import typing import logging import warnings -import math from sklearn.model_selection import train_test_split, GroupShuffleSplit, \ KFold, GroupKFold, StratifiedKFold @@ -110,10 +109,7 @@ def index_split(self, **kwargs) -> SplitType: self.log_split(split_dict, "Created Single Split with:") return split_dict - def index_split_stratified( - self, - stratify_key: str = "label", - **kwargs) -> SplitType: + def index_split_stratified(self, stratify_key: str = "label", **kwargs) -> SplitType: """ Splits the dataset's indices in a stratified way @@ -134,13 +130,9 @@ def index_split_stratified( """ stratify = [d[stratify_key] for d in self._dataset] - return self.index_split(stratify=stratify, **kwargs) - def index_split_grouped( - self, - groups_key: str = "id", - **kwargs) -> SplitType: + def index_split_grouped(self, groups_key: str = "id", **kwargs) -> SplitType: """ Splits the dataset's indices in a stratified way @@ -184,10 +176,10 @@ def index_split_grouped( self.log_split(split_dict, "Created Single Split with:") return split_dict - # TODO: Maybe add kfolds without fixed testset? def index_kfold_fixed_test(self, **kwargs) -> typing.Iterable[SplitType]: """ - Calculates splits for a random kfold with given testset + Calculates splits for a random kfold with given testset. 
+ If :param:`test_size` is zero, a normal kfold is generated Parameters ---------- @@ -223,6 +215,7 @@ def index_kfold_fixed_test_stratified( **kwargs) -> typing.Iterable[SplitType]: """ Calculates splits for a stratified kfold with given testset + If :param:`test_size` is zero, a normal kfold is generated Parameters ---------- @@ -261,6 +254,7 @@ def index_kfold_fixed_test_grouped(self, groups_key: str = "id", **kwargs) -> typing.Iterable[SplitType]: """ Calculates splits for a stratified kfold with given testset + If :param:`test_size` is zero, a normal kfold is generated Parameters ---------- diff --git a/rising/utils/checktype.py b/rising/utils/checktype.py index 9eac016e..b0d09447 100644 --- a/rising/utils/checktype.py +++ b/rising/utils/checktype.py @@ -1,5 +1,18 @@ def check_scalar(x): + """ + Provide interface to check for scalars + + Parameters + ---------- + x: typing.Any + object to check for scalar + + Returns + ------- + bool + True if input is scalar + """ if isinstance(x, (int, float)): return True else: diff --git a/tests/loading/test_dataset.py b/tests/loading/test_dataset.py index 8cd1a88b..1afcd82a 100644 --- a/tests/loading/test_dataset.py +++ b/tests/loading/test_dataset.py @@ -10,8 +10,6 @@ from rising.loading import get_debug_mode, set_debug_mode -# TODO: Additional Tests for subsetdataset - class LoadDummySample: def __call__(self, path, *args, **kwargs): data = {'data': np.random.rand(1, 256, 256), From d701a938939a45c1cce21e119e07ecff8cbcbf6d Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Tue, 3 Dec 2019 20:41:58 +0100 Subject: [PATCH 39/39] remove comments which were moved to github --- rising/loading/container.py | 2 -- rising/loading/dataset.py | 10 ++-------- rising/loading/splitter.py | 4 ---- 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/rising/loading/container.py b/rising/loading/container.py index a1a49a6d..9677f156 100644 --- a/rising/loading/container.py +++ b/rising/loading/container.py @@ -39,8 +39,6 @@ 
def split_by_index(self, split: SplitType) -> None: for key, idx in split.items(): self._dset[key] = self._dataset.get_subset(idx) - # TODO: Shouldn"t the kfold methods instead yield the current datasets - # instead of the whole cointainer? def kfold_by_index(self, splits: typing.Iterable[SplitType]): """ Produces kfold splits based on the given indices. diff --git a/rising/loading/dataset.py b/rising/loading/dataset.py index 0a0eeae5..f881901f 100644 --- a/rising/loading/dataset.py +++ b/rising/loading/dataset.py @@ -9,6 +9,7 @@ from torch.utils.data import Dataset as TorchDset from rising.loading.debug_mode import get_debug_mode +from rising import AbstractMixin from torch.multiprocessing import Pool @@ -328,10 +329,7 @@ def __len__(self) -> int: return len(self.data) -# TODO: Maybe we should add the dataset baseclass as baseclass of this as well -# (since it should just extend it and still have all the other dataset -# functionalities)? -class IDManager: +class IDManager(AbstractMixin): def __init__(self, id_key: str, cache_ids: bool = True, **kwargs): """ Helper class to add additional functionality to Datasets @@ -442,8 +440,6 @@ def __init__(self, data_path, load_fn, id_key, cache_ids=True, **kwargs : additional keyword arguments """ - # TODO: Shouldn't we call the baseclasses explicitly here? with super - # it is not clear, which baseclass is actually called super().__init__(data_path=data_path, load_fn=load_fn, id_key=id_key, cache_ids=cache_ids, **kwargs) @@ -467,7 +463,5 @@ def __init__(self, data_path, load_fn, id_key, cache_ids=True, **kwargs : additional keyword arguments """ - # TODO: Shouldn't we call the baseclasses explicitly here? 
with super - # it is not clear, which baseclass is actually called super().__init__(data_path=data_path, load_fn=load_fn, id_key=id_key, cache_ids=cache_ids, **kwargs) diff --git a/rising/loading/splitter.py b/rising/loading/splitter.py index 304391ac..019407d0 100644 --- a/rising/loading/splitter.py +++ b/rising/loading/splitter.py @@ -155,10 +155,6 @@ def index_split_grouped(self, groups_key: str = "id", **kwargs) -> SplitType: -------- Shuffling cannot be deactivated """ - # TODO: maybe we should implement a single split function, which - # handles random, stratificated and grouped splitting? This would not - # be hard at all (based on the first look at sklearn internals) and - # would remove some code duolication in here split_dict = {} groups = [d[groups_key] for d in self._dataset]