This repository was archived by the owner on Jan 12, 2024. It is now read-only.
WIP: Loading #1
Merged

Commits (40):
e35a0d9  Add loading package (justusschock)
088bbe3  add collate fn (justusschock)
540af3a  Update setup.py (justusschock)
7cae604  update requirements (justusschock)
e880e16  add datasets (justusschock)
c08ef2a  add debug mode (justusschock)
00c5730  add data loader (justusschock)
f76f6bc  add containers (justusschock)
ee1b2a4  splitter (justusschock)
7f749c9  update requirements (justusschock)
adf282a  rename test directories (justusschock)
4416632  add _utils for testing (justusschock)
ca7fde2  rename test directory (justusschock)
d8e35c4  pep8 changes (justusschock)
1599a28  add dummy testcase for numpy_collate (justusschock)
f1096ed  add container tests (justusschock)
197c4d7  add dataset tests (justusschock)
a046704  Add dummy files (justusschock)
05d1e3a  Export Basic API (should we export more?) (justusschock)
5943a18  autopep8 fix (mibaumgartner)
6d3e5d9  Add imports in tests package (justusschock)
b95ee1c  document id managers (justusschock)
ba5017e  Move future imports to top of file (justusschock)
50153e7  Add docstrings, todos and comments for splitter (justusschock)
21af266  Add docstrings for data container (justusschock)
9325788  add docstinrgs for DataContainerID (justusschock)
1bef306  Add comment (justusschock)
85946d9  move future imports to top of file (justusschock)
1008e27  autopep8 fix (mibaumgartner)
5edf840  fix bugs and tests (mibaumgartner)
b949d37  dataset tests (mibaumgartner)
8f7ffa4  collate and debug mode tests (mibaumgartner)
91bef30  tqdm to requirements (mibaumgartner)
1de2e8e  autopep8 fix (mibaumgartner)
a4024c8  remove print (mibaumgartner)
66e236c  merge master (mibaumgartner)
344affc  loader tests (mibaumgartner)
5d352e5  loader tests (mibaumgartner)
339e639  add splitter to coverage ignore, convert todos to github (mibaumgartner)
d701a93  remove comments which were moved to github (mibaumgartner)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Coverage configuration, hunk `@@ -11,4 +11,5 @@` (the `ignore` list after the change):

```yaml
coverage:
  ignore:
    - "tests/"
    - "*/__init.py"
    - "_version.py"
```
Requirements, hunk `@@ -1,2 +1,6 @@` (four dependencies added):

```diff
 numpy
 torch
+threadpoolctl
+pandas
+sklearn
+tqdm
```
A new package `__init__` exporting the basic loading API, hunk `@@ -0,0 +1,4 @@`:

```python
from rising.loading.collate import numpy_collate
from rising.loading.dataset import Dataset
from rising.loading.loader import DataLoader
from rising.loading.debug_mode import get_debug_mode, set_debug_mode, switch_debug_mode
```
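The `rising.loading.debug_mode` module itself is not part of this diff, so its implementation is not shown here. Judging only by the three exported names, a plausible sketch is a module-level flag guarded by accessor functions (everything below is a hypothetical reconstruction, not the library's actual code):

```python
# Hypothetical sketch of a module-level debug flag, mirroring the exported
# names get_debug_mode / set_debug_mode / switch_debug_mode. The real
# rising.loading.debug_mode module is not shown in this diff.

_DEBUG_MODE = False  # module-level state, accessed only via the functions below


def get_debug_mode() -> bool:
    """Return the current debug-mode flag."""
    return _DEBUG_MODE


def set_debug_mode(mode: bool) -> None:
    """Set the debug-mode flag explicitly."""
    global _DEBUG_MODE
    _DEBUG_MODE = bool(mode)


def switch_debug_mode() -> None:
    """Toggle the debug-mode flag."""
    set_debug_mode(not get_debug_mode())
```

A debug mode like this is typically used to force single-process data loading so that breakpoints and stack traces work inside dataset code.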
A new collate function for numpy batches, hunk `@@ -0,0 +1,47 @@`:

```python
import collections.abc
from typing import Any

import numpy as np
import torch

default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")


def numpy_collate(batch: Any) -> Any:
    """
    Collates the samples into a whole batch of numpy arrays.
    PyTorch tensors, scalar values and sequences are cast to arrays
    automatically.

    Parameters
    ----------
    batch : Any
        a batch of samples; in most cases either a sequence,
        a mapping or a mixture of both

    Returns
    -------
    Any
        collated batch with optionally converted type (to numpy array)
    """
    elem = batch[0]
    if isinstance(elem, np.ndarray):
        return np.stack(batch, 0)
    elif isinstance(elem, torch.Tensor):
        # detach tensors, move them to CPU and recurse on the numpy arrays
        return numpy_collate([b.detach().cpu().numpy() for b in batch])
    elif isinstance(elem, (int, float)):
        return np.array(batch)
    elif isinstance(elem, str):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: numpy_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return type(elem)(*(numpy_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(type(elem)))
```
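A minimal usage sketch of the collate behaviour. To keep it self-contained it reimplements only the numpy-side branches rather than importing from `rising`, and omits the `torch.Tensor` branch; the recursion structure is the same as above:

```python
import collections.abc
from typing import Any

import numpy as np


def numpy_collate(batch: Any) -> Any:
    # Reduced, numpy-only copy of the collate logic above
    # (torch and namedtuple branches omitted for brevity).
    elem = batch[0]
    if isinstance(elem, np.ndarray):
        return np.stack(batch, 0)
    elif isinstance(elem, (int, float)):
        return np.array(batch)
    elif isinstance(elem, str):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: numpy_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, collections.abc.Sequence):
        return [numpy_collate(samples) for samples in zip(*batch)]
    raise TypeError(f"unsupported element type: {type(elem)}")


# Three dict samples collate into one dict of stacked arrays.
batch = [{"data": np.zeros((2, 2)), "label": i} for i in range(3)]
collated = numpy_collate(batch)
print(collated["data"].shape)  # (3, 2, 2)
print(collated["label"])       # [0 1 2]
```

Note the ordering of the branches: `str` must be handled before the generic `Sequence` case, because Python strings are sequences and would otherwise recurse character by character.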
A new data container module, hunk `@@ -0,0 +1,268 @@`:

```python
from __future__ import annotations

import pathlib
import typing
from collections import defaultdict

import pandas as pd

from rising.loading.dataset import Dataset
from rising.loading.splitter import SplitType


class DataContainer:
    def __init__(self, dataset: Dataset):
        """
        Handles the splitting of datasets from different sources

        Parameters
        ----------
        dataset : Dataset
            the dataset to split
        """
        self._dataset = dataset
        self._dset = {}
        self._fold = None
        super().__init__()

    def split_by_index(self, split: SplitType) -> None:
        """
        Splits the dataset by a given split dict

        Parameters
        ----------
        split : dict
            a dictionary mapping each split name to a list of indices
        """
        for key, idx in split.items():
            self._dset[key] = self._dataset.get_subset(idx)

    def kfold_by_index(self, splits: typing.Iterable[SplitType]):
        """
        Produces kfold splits based on the given indices.

        Parameters
        ----------
        splits : list
            list containing a split dict for each fold

        Yields
        ------
        DataContainer
            the data container with updated dataset splits
        """
        for fold, split in enumerate(splits):
            self.split_by_index(split)
            self._fold = fold
            yield self
        self._fold = None

    def split_by_csv(self, path: typing.Union[pathlib.Path, str],
                     index_column: str, **kwargs) -> None:
        """
        Splits the dataset by splits given in a CSV file

        Parameters
        ----------
        path : str or pathlib.Path
            the path to the csv file
        index_column : str
            the label of the index column
        **kwargs :
            additional keyword arguments (see :func:`pandas.read_csv` for
            details)
        """
        df = pd.read_csv(path, **kwargs).set_index(index_column)
        col = list(df.columns)
        self.split_by_index(self._read_split_from_df(df, col[0]))

    def kfold_by_csv(self, path: typing.Union[pathlib.Path, str],
                     index_column: str, **kwargs):
        """
        Produces kfold splits based on the given csv file.

        Parameters
        ----------
        path : str or pathlib.Path
            the path to the csv file
        index_column : str
            the label of the index column
        **kwargs :
            additional keyword arguments (see :func:`pandas.read_csv` for
            details)

        Yields
        ------
        DataContainer
            the data container with updated dataset splits
        """
        df = pd.read_csv(path, **kwargs).set_index(index_column)
        splits = [self._read_split_from_df(df, fold) for fold in df.columns]
        yield from self.kfold_by_index(splits)

    @staticmethod
    def _read_split_from_df(df: pd.DataFrame, col: str) -> SplitType:
        """
        Helper function to read a split from a given data frame

        Parameters
        ----------
        df : pandas.DataFrame
            the dataframe containing the split
        col : str
            the column inside the data frame containing the split

        Returns
        -------
        dict
            a dictionary mapping each split name to a list of indices
        """
        split = defaultdict(list)
        for index, row in df[[col]].iterrows():
            split[str(row[col])].append(index)
        return split

    @property
    def dset(self) -> typing.Dict[str, Dataset]:
        if not self._dset:
            raise AttributeError("No split found.")
        return self._dset

    @property
    def fold(self) -> int:
        if self._fold is None:
            raise AttributeError(
                "Fold not specified. Call `kfold_by_index` first.")
        return self._fold


class DataContainerID(DataContainer):
    """
    Data container class for datasets with an ID
    """

    def split_by_id(self, split: SplitType) -> None:
        """
        Splits the internal dataset by the given splits

        Parameters
        ----------
        split : dict
            a dictionary mapping each split name to a list of sample IDs
        """
        split_idx = defaultdict(list)
        for key, ids in split.items():
            for _id in ids:
                split_idx[key].append(self._dataset.get_index_by_id(_id))
        super().split_by_index(split_idx)

    def kfold_by_id(self, splits: typing.Iterable[SplitType]):
        """
        Produces kfold splits by ID

        Parameters
        ----------
        splits : list
            list of dicts, each containing the splits for a separate fold

        Yields
        ------
        DataContainerID
            the data container with updated internal datasets
        """
        for fold, split in enumerate(splits):
            self.split_by_id(split)
            self._fold = fold
            yield self
        self._fold = None

    def split_by_csv_id(self, path: typing.Union[pathlib.Path, str],
                        id_column: str, **kwargs) -> None:
        """
        Splits the internal dataset by a given id column in a given csv file

        Parameters
        ----------
        path : str or pathlib.Path
            the path to the csv file
        id_column : str
            the label of the id column
        **kwargs :
            additional keyword arguments (see :func:`pandas.read_csv` for
            details)
        """
        df = pd.read_csv(path, **kwargs).set_index(id_column)
        col = list(df.columns)
        self.split_by_id(self._read_split_from_df(df, col[0]))

    def kfold_by_csv_id(self, path: typing.Union[pathlib.Path, str],
                        id_column: str, **kwargs):
        """
        Produces kfold splits by an ID column of a given csv file

        Parameters
        ----------
        path : str or pathlib.Path
            the path to the csv file
        id_column : str
            the label of the id column
        **kwargs :
            additional keyword arguments (see :func:`pandas.read_csv` for
            details)

        Yields
        ------
        DataContainerID
            the data container with updated internal datasets
        """
        df = pd.read_csv(path, **kwargs).set_index(id_column)
        splits = [self._read_split_from_df(df, fold) for fold in df.columns]
        yield from self.kfold_by_id(splits)

    def save_split_to_csv_id(self,
                             path: typing.Union[pathlib.Path, str],
                             id_key: str,
                             split_column: str = 'split',
                             **kwargs) -> None:
        """
        Saves the current split to a csv file

        Parameters
        ----------
        path : str or pathlib.Path
            the path of the csv file
        id_key : str
            the id key inside the csv file
        split_column : str
            the name of the split column inside the csv file
        **kwargs :
            additional keyword arguments (see :meth:`pandas.DataFrame.to_csv`
            for details)
        """
        split_dict = {str(id_key): [], str(split_column): []}
        for key, item in self._dset.items():
            for sample in item:
                split_dict[str(id_key)].append(sample[id_key])
                split_dict[str(split_column)].append(str(key))
        pd.DataFrame(split_dict).to_csv(path, **kwargs)
```
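A sketch of how `split_by_index` is meant to be used. The real `rising.loading.dataset.Dataset` is not part of this diff, so a hypothetical stand-in with a `get_subset` method is defined here, together with a reduced copy of the container (index-based splitting only, no pandas dependency):

```python
# Hypothetical usage sketch. ToyDataset stands in for rising's Dataset,
# which is not shown in this diff; only the get_subset contract is assumed.


class ToyDataset:
    def __init__(self, samples):
        self.samples = list(samples)

    def get_subset(self, indices):
        # return a new dataset restricted to the given indices
        return ToyDataset(self.samples[i] for i in indices)

    def __len__(self):
        return len(self.samples)


class DataContainer:
    # reduced copy of the container above: index-based splitting only
    def __init__(self, dataset):
        self._dataset = dataset
        self._dset = {}

    def split_by_index(self, split):
        for key, idx in split.items():
            self._dset[key] = self._dataset.get_subset(idx)


container = DataContainer(ToyDataset(range(10)))
container.split_by_index({"train": [0, 1, 2, 3, 4, 5, 6, 7], "val": [8, 9]})
print(len(container._dset["train"]), len(container._dset["val"]))  # 8 2
```

The CSV-based variants build the same split dict from a dataframe column (via `_read_split_from_df`) before delegating to `split_by_index`, so the dict shown here is the common interchange format.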