-
Notifications
You must be signed in to change notification settings - Fork 5
feat: Add InputConversion & OutputConversion for nn interface
#625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
77 commits
Select commit
Hold shift + click to select a range
019a6d8
perf: Add special case to `Table.add_rows` to increase performance
Marsmaennchen221 5471a03
style: apply automated linter fixes
megalinter-bot 0802e0e
perf: change number_of_rows to number_of_columns in `add_rows` as 0 c…
Marsmaennchen221 a177d97
perf: special case if rows has columns but no rows
Marsmaennchen221 8450dc6
test: Added test for `Table.add_rows` for "same schema add no rows"
Marsmaennchen221 bae8e4d
perf: suggested performance upgrades for nn._fnn_layer and nn._model
Marsmaennchen221 4ae795e
make dataloader shuffle data each epoch
sibre28 350b771
add learning_rate parameter to fit() function
sibre28 905f103
raise an Error if test data doesnt match format of train data
sibre28 7593204
add abstract layer class
sibre28 f5f291e
make forward return tensor instead of float and change method to buil…
sibre28 7284c89
remove uncoverable lines from codecov
sibre28 0572cb4
small change
sibre28 0eefee7
small change
sibre28 f5bdf22
add abstract functions
sibre28 e3543fb
change for linter
sibre28 41bdd6a
change for linter
sibre28 1487ea2
change for linter
sibre28 8df1f26
change for linter
sibre28 37238db
Merge branch 'main' into 610-improve-fnn-layer-and-model-performance-…
sibre28 03907fd
style: apply automated linter fixes
megalinter-bot 7f56aa2
style: apply automated linter fixes
megalinter-bot 0923989
change for linter
sibre28 4c64e53
Merge remote-tracking branch 'origin/610-improve-fnn-layer-and-model-…
sibre28 625c0b6
change for linter
sibre28 5feeff5
style: apply automated linter fixes
megalinter-bot 3c98c23
accumulate epoch and batch counters and loss over all fit-calls
sibre28 6a05dd6
Merge remote-tracking branch 'origin/610-improve-fnn-layer-and-model-…
sibre28 8191fc4
style: apply automated linter fixes
megalinter-bot 908409a
add input_size property to Layer
sibre28 2e253af
Merge remote-tracking branch 'origin/610-improve-fnn-layer-and-model-…
sibre28 846a36d
raise InputSizeError if input size and table size mismatch
sibre28 c4c0965
style: apply automated linter fixes
megalinter-bot fe842e8
perf: suggested performance upgrades for dataloader in TaggedTable an…
Marsmaennchen221 4892d5d
rename FNNLayer to Forward Layer and put it in separate File
sibre28 d42e44f
Merge remote-tracking branch 'origin/610-improve-fnn-layer-and-model-…
sibre28 0ca1ec8
style: apply automated linter fixes
megalinter-bot d39ada7
remove unnecessary test file
sibre28 f567310
Merge remote-tracking branch 'origin/610-improve-fnn-layer-and-model-…
sibre28 a41c330
Merge remote-tracking branch 'origin/suggested_perf_upgrades_model_an…
sibre28 8b30b1a
Merge remote-tracking branch 'origin/suggested_perf_upgrades_tagged_t…
sibre28 bf74b67
merge suggested changes
sibre28 1d0bfa1
style: apply automated linter fixes
megalinter-bot f61c687
style: apply automated linter fixes
megalinter-bot 9a6d706
update test to cover into_dataloader_with_classes
sibre28 3db67cd
Merge remote-tracking branch 'origin/610-improve-fnn-layer-and-model-…
sibre28 4100e6d
remove inplace modifications of model and reset loss after every epoch
sibre28 7ecec5b
adjust loss calculation
sibre28 7833e05
style: apply automated linter fixes
megalinter-bot b7da6df
loss_sum
sibre28 c8040ed
Merge remote-tracking branch 'origin/610-improve-fnn-layer-and-model-…
sibre28 2501c4c
style: apply automated linter fixes
megalinter-bot 8ecb9fd
fix bug
sibre28 cbd69f0
style: apply automated linter fixes
megalinter-bot 8222b5f
fix bug
sibre28 080fe03
Merge remote-tracking branch 'origin/610-improve-fnn-layer-and-model-…
sibre28 99ec26c
style: apply automated linter fixes
megalinter-bot 2842732
Merge branch 'main' into 610-improve-fnn-layer-and-model-performance-…
sibre28 02e2996
style: apply automated linter fixes
megalinter-bot 265e55c
added input and output layer interface
Gerhardsa0 0ef68dc
Changes by alexg
Gerhardsa0 c4252dc
Changes by Simon
Gerhardsa0 c556f72
Merge branch 'main' of https://github.com/Safe-DS/Library into 621-fe…
Marsmaennchen221 d5015e1
refactor: linter
Marsmaennchen221 d7d41b2
refactor: codecov
Marsmaennchen221 d8f2551
refactor: linter
Marsmaennchen221 2e07146
style: apply automated linter fixes
megalinter-bot 05e0703
style: apply automated linter fixes
megalinter-bot f4a69de
refactor: lazy imports
Marsmaennchen221 8e0228c
style: apply automated linter fixes
megalinter-bot d02f03b
refactor: non global internal model creation
Marsmaennchen221 631ae38
Merge branch '621-feat-add-input-layer-for-nn-interface' of https://g…
Marsmaennchen221 ac631d8
feat: Added predict type to `InputConversion`
Marsmaennchen221 53b2b82
refactor: changed `TypeVar` to match correct classes
Marsmaennchen221 bea4d1a
style: apply automated linter fixes
megalinter-bot 9719467
added documentation
Gerhardsa0 3ecc3c7
Linter changes
Gerhardsa0 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from typing import TYPE_CHECKING, Generic, TypeVar | ||
|
|
||
| if TYPE_CHECKING: | ||
| from torch.utils.data import DataLoader | ||
|
|
||
| from safeds.data.tabular.containers import Table, TaggedTable, TimeSeries | ||
|
|
||
| FT = TypeVar("FT", TaggedTable, TimeSeries) | ||
| PT = TypeVar("PT", Table, TimeSeries) | ||
|
|
||
|
|
||
| class _InputConversion(Generic[FT, PT], ABC): | ||
| """The input conversion for a neural network, defines the input parameters for the neural network.""" | ||
|
|
||
| @property | ||
| @abstractmethod | ||
| def _data_size(self) -> int: | ||
| pass # pragma: no cover | ||
|
|
||
| @abstractmethod | ||
| def _data_conversion_fit(self, input_data: FT, batch_size: int, num_of_classes: int = 1) -> DataLoader: | ||
| pass # pragma: no cover | ||
|
|
||
| @abstractmethod | ||
| def _data_conversion_predict(self, input_data: PT, batch_size: int) -> DataLoader: | ||
| pass # pragma: no cover | ||
|
|
||
| @abstractmethod | ||
| def _is_fit_data_valid(self, input_data: FT) -> bool: | ||
| pass # pragma: no cover | ||
|
|
||
| @abstractmethod | ||
| def _is_predict_data_valid(self, input_data: PT) -> bool: | ||
| pass # pragma: no cover | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| from torch.utils.data import DataLoader | ||
|
|
||
| from safeds.data.tabular.containers import Table, TaggedTable | ||
| from safeds.ml.nn._input_conversion import _InputConversion | ||
|
|
||
|
|
||
| class InputConversionTable(_InputConversion[TaggedTable, Table]): | ||
| """The input conversion for a neural network, defines the input parameters for the neural network.""" | ||
|
|
||
| def __init__(self, feature_names: list[str], target_name: str) -> None: | ||
| """ | ||
| Define the input parameters for the neural network in the input conversion. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| feature_names | ||
| The names of the features for the input table, used as features for the training. | ||
| target_name | ||
| The name of the target for the input table, used as target for the training. | ||
| """ | ||
| self._feature_names = feature_names | ||
| self._target_name = target_name | ||
|
|
||
| @property | ||
| def _data_size(self) -> int: | ||
| return len(self._feature_names) | ||
|
|
||
| def _data_conversion_fit(self, input_data: TaggedTable, batch_size: int, num_of_classes: int = 1) -> DataLoader: | ||
| return input_data._into_dataloader_with_classes( | ||
| batch_size, | ||
| num_of_classes, | ||
| ) | ||
|
|
||
| def _data_conversion_predict(self, input_data: Table, batch_size: int) -> DataLoader: | ||
| return input_data._into_dataloader(batch_size) | ||
|
|
||
| def _is_fit_data_valid(self, input_data: TaggedTable) -> bool: | ||
| return (sorted(input_data.features.column_names)).__eq__(sorted(self._feature_names)) | ||
|
|
||
| def _is_predict_data_valid(self, input_data: Table) -> bool: | ||
| return (sorted(input_data.column_names)).__eq__(sorted(self._feature_names)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.