diff --git a/src/safeds/data/tabular/transformation/__init__.py b/src/safeds/data/tabular/transformation/__init__.py index 920f0b7e5..a837d3165 100644 --- a/src/safeds/data/tabular/transformation/__init__.py +++ b/src/safeds/data/tabular/transformation/__init__.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from ._discretizer import Discretizer from ._invertible_table_transformer import InvertibleTableTransformer + from ._k_nearest_neighbors_imputer import KNearestNeighborsImputer from ._label_encoder import LabelEncoder from ._one_hot_encoder import OneHotEncoder from ._range_scaler import RangeScaler @@ -27,6 +28,7 @@ "SimpleImputer": "._simple_imputer:SimpleImputer", "StandardScaler": "._standard_scaler:StandardScaler", "TableTransformer": "._table_transformer:TableTransformer", + "KNearestNeighborsImputer": "._k_nearest_neighbors_imputer:KNearestNeighborsImputer", }, ) @@ -40,4 +42,5 @@ "SimpleImputer", "StandardScaler", "TableTransformer", + "KNearestNeighborsImputer", ] diff --git a/src/safeds/data/tabular/transformation/_k_nearest_neighbors_imputer.py b/src/safeds/data/tabular/transformation/_k_nearest_neighbors_imputer.py new file mode 100644 index 000000000..890749ba4 --- /dev/null +++ b/src/safeds/data/tabular/transformation/_k_nearest_neighbors_imputer.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from safeds._utils import _structural_hash +from safeds._validation import _check_bounds, _check_columns_exist, _ClosedBound +from safeds.data.tabular.containers import Table +from safeds.exceptions import TransformerNotFittedError + +from ._table_transformer import TableTransformer + +if TYPE_CHECKING: + from sklearn.impute import KNNImputer as sk_KNNImputer + + +class KNearestNeighborsImputer(TableTransformer): + """ + The KNearestNeighborsImputer replaces missing values in given Columns with the mean value of the K-nearest neighbors. + + Parameters + ---------- + neighbor_count: + The number of neighbors to consider when imputing missing values. + column_names: + The list of columns used to impute missing values. If 'None', all columns are used. + value_to_replace: + The placeholder for the missing values. All occurrences of`missing_values` will be imputed. + """ + + # ------------------------------------------------------------------------------------------------------------------ + # Dunder methods + # ------------------------------------------------------------------------------------------------------------------ + + def __init__( + self, + neighbor_count: int, + *, + column_names: str | list[str] | None = None, + value_to_replace: float | str | None = None, + ) -> None: + super().__init__(column_names) + + _check_bounds(name="neighbor_count", actual=neighbor_count, lower_bound=_ClosedBound(1)) + + # parameter + self._neighbor_count: int = neighbor_count + self._value_to_replace: float | str | None = value_to_replace + + # attributes + self._wrapped_transformer: sk_KNNImputer | None = None + + def __hash__(self) -> int: + return _structural_hash( + super().__hash__(), + self._neighbor_count, + self._value_to_replace, + # Leave out the internal state for faster hashing + ) + + # ------------------------------------------------------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------------------------------------------------------ + + @property + def is_fitted(self) -> bool: + """Whether the transformer is fitted.""" + return self._wrapped_transformer is not None + + @property + def neighbor_count(self) -> int: + """The number of neighbors to consider when imputing missing values.""" + return self._neighbor_count + + @property + def value_to_replace(self) -> float | str | None: + """The value to replace.""" + return self._value_to_replace + + # ------------------------------------------------------------------------------------------------------------------ + # Learning and transformation + # ------------------------------------------------------------------------------------------------------------------ + + def fit(self, table: Table) -> KNearestNeighborsImputer: + """ + Learn a transformation for a set of columns in a table. + + **Note:** This transformer is not modified. + + Parameters + ---------- + table: + The table used to fit the transformer. + + Returns + ------- + fitted_transformer: + The fitted transformer. + + Raises + ------ + ColumnNotFoundError + If one of the columns, that should be fitted is not in the table. + """ + from sklearn.impute import KNNImputer as sk_KNNImputer + + if table.row_count == 0: + raise ValueError("The KNearestNeighborsImputer cannot be fitted because the table contains 0 rows.") + + if self._column_names is None: + column_names = table.column_names + else: + column_names = self._column_names + _check_columns_exist(table, column_names) + + value_to_replace = self._value_to_replace + + if self._value_to_replace is None: + from numpy import nan + + value_to_replace = nan + + wrapped_transformer = sk_KNNImputer(n_neighbors=self._neighbor_count, missing_values=value_to_replace) + wrapped_transformer.set_output(transform="polars") + wrapped_transformer.fit( + table.remove_columns_except(column_names)._data_frame, + ) + + result = KNearestNeighborsImputer(self._neighbor_count, column_names=column_names) + result._wrapped_transformer = wrapped_transformer + + return result + + def transform(self, table: Table) -> Table: + """ + Apply the learned transformation to a table. + + **Note:** The given table is not modified. + + Parameters + ---------- + table: + The table to wich the learned transformation is applied. + + Returns + ------- + transformed_table: + The transformed table. + + Raises + ------ + TransformerNotFittedError + If the transformer is not fitted. + ColumnNotFoundError + If one of the columns, that should be transformed is not in the table. + """ + if self._column_names is None or self._wrapped_transformer is None: + raise TransformerNotFittedError + + _check_columns_exist(table, self._column_names) + + new_data = self._wrapped_transformer.transform( + table.remove_columns_except(self._column_names)._data_frame, + ) + + return Table._from_polars_lazy_frame( + table._lazy_frame.with_columns(new_data), + ) diff --git a/tests/safeds/data/tabular/transformation/test_k_nearest_neighbors_imputer.py b/tests/safeds/data/tabular/transformation/test_k_nearest_neighbors_imputer.py new file mode 100644 index 000000000..de0700dbf --- /dev/null +++ b/tests/safeds/data/tabular/transformation/test_k_nearest_neighbors_imputer.py @@ -0,0 +1,211 @@ +import pytest +from safeds.data.tabular.containers import Table +from safeds.data.tabular.transformation import KNearestNeighborsImputer +from safeds.exceptions import ( + ColumnNotFoundError, + OutOfBoundsError, + TransformerNotFittedError, +) + + +class TestInit: + def test_should_raise_value_error(self) -> None: + with pytest.raises(OutOfBoundsError): + KNearestNeighborsImputer(neighbor_count=0) + + def test_neighbor_count(self) -> None: + knn = KNearestNeighborsImputer(neighbor_count=5) + assert knn.neighbor_count == 5 + + def test_value_to_replace_none(self) -> None: + knn = KNearestNeighborsImputer(neighbor_count=5) + assert knn.value_to_replace is None + + def test_value_to_replace_number(self) -> None: + knn = KNearestNeighborsImputer(neighbor_count=5, value_to_replace=1) + assert knn.value_to_replace == 1 + + +class TestFit: + def test_should_raise_if_column_not_found(self) -> None: + table = Table( + { + "col1": [0.0, 5.0, 10.0], + }, + ) + + with pytest.raises(ColumnNotFoundError): + KNearestNeighborsImputer(neighbor_count=5, column_names=["col2", "col3"]).fit(table) + + def test_should_raise_if_table_contains_no_rows(self) -> None: + with pytest.raises( + ValueError, + match=r"The KNearestNeighborsImputer cannot be fitted because the table contains 0 rows", + ): + KNearestNeighborsImputer(neighbor_count=5).fit(Table({"col1": []})) + + def test_should_not_change_original_transformer(self) -> None: + table = Table( + { + "col1": [0.0, 5.0, 10.0], + }, + ) + + transformer = KNearestNeighborsImputer(neighbor_count=5) + transformer.fit(table) + + assert transformer._column_names is None + assert transformer._wrapped_transformer is None + + +class TestTransform: + def test_should_raise_if_column_not_found(self) -> None: + table_to_fit = Table( + { + "col1": [0.0, 5.0, 10.0], + "col2": [5.0, 50.0, 100.0], + }, + ) + + transformer = KNearestNeighborsImputer(neighbor_count=5) + + table_to_transform = Table( + { + "col3": ["a", "b", "c"], + }, + ) + + with pytest.raises(ColumnNotFoundError): + transformer.fit(table_to_fit).transform(table_to_transform) + + def test_should_raise_if_not_fitted(self) -> None: + table = Table( + { + "col1": [0.0, 5.0, 10.0], + }, + ) + + transformer = KNearestNeighborsImputer(neighbor_count=5) + + with pytest.raises(TransformerNotFittedError): + transformer.transform(table) + + +class TestIsFitted: + def test_should_return_false_before_fitting(self) -> None: + transformer = KNearestNeighborsImputer(neighbor_count=5) + assert not transformer.is_fitted + + def test_should_return_true_after_fitting(self) -> None: + table = Table( + { + "col1": [0.0, 5.0, 10.0], + }, + ) + + transformer = KNearestNeighborsImputer(neighbor_count=5) + fitted_transformer = transformer.fit(table) + assert fitted_transformer.is_fitted + + +class TestFitAndTransform: + @pytest.mark.parametrize( + ("table", "column_names", "expected"), + [ + ( + Table( + { + "col1": [1, 2, None], + "col2": [1, 2, 3], + }, + ), + ["col1"], + Table( + { + "col1": [1, 2, 2], # Assuming k=1, the nearest neighbor for the missing value is 2. + "col2": [1, 2, 3], + }, + ), + ), + ( + Table( + { + "col1": [1, 2, None, 4], + "col2": [1, 2, 3, 4], + }, + ), + ["col1"], + Table( + { + "col1": [1, 2, 2, 4], # Assuming k=1, the nearest neighbor for the missing value is 2. + "col2": [1, 2, 3, 4], + }, + ), + ), + ], + ids=["one_column", "two_columns"], + ) + def test_should_return_fitted_transformer_and_transformed_table( + self, + table: Table, + column_names: list[str] | None, # noqa: ARG002 + expected: Table, + ) -> None: + fitted_transformer, transformed_table = KNearestNeighborsImputer( + neighbor_count=1, + column_names=None, + value_to_replace=None, + ).fit_and_transform(table) + assert fitted_transformer.is_fitted + assert transformed_table == expected + + @pytest.mark.parametrize( + ("table", "column_names", "expected"), + [ + ( + Table( + { + "col1": [1, 2, None, 4], + "col2": [1, None, 3, 4], + }, + ), + ["col1"], + Table( + { + "col1": [1, 2, 7 / 3, 4], # Assuming k=1, the nearest neighbor for the missing value is 2. + "col2": [1, 8 / 3, 3, 4], + }, + ), + ), + ], + ids=["two_columns"], + ) + def test_should_return_fitted_transformer_and_transformed_table_with_correct_values( + self, + table: Table, + column_names: list[str] | None, # noqa: ARG002 + expected: Table, + ) -> None: + fitted_transformer, transformed_table = KNearestNeighborsImputer( + neighbor_count=3, + value_to_replace=None, + ).fit_and_transform(table) + assert fitted_transformer.is_fitted + assert transformed_table == expected + + def test_should_not_change_original_table(self) -> None: + table = Table( + { + "col1": [0.0, 5.0, 10.0], + }, + ) + + KNearestNeighborsImputer(neighbor_count=5).fit_and_transform(table) + + expected = Table( + { + "col1": [0.0, 5.0, 10.0], + }, + ) + + assert table == expected diff --git a/tests/safeds/data/tabular/transformation/test_table_transformer.py b/tests/safeds/data/tabular/transformation/test_table_transformer.py index 83c374bd3..d03b157e1 100644 --- a/tests/safeds/data/tabular/transformation/test_table_transformer.py +++ b/tests/safeds/data/tabular/transformation/test_table_transformer.py @@ -4,6 +4,7 @@ from safeds.data.tabular.containers import Table from safeds.data.tabular.transformation import ( Discretizer, + KNearestNeighborsImputer, LabelEncoder, OneHotEncoder, RangeScaler, @@ -67,6 +68,7 @@ def transformers() -> list[TableTransformer]: + transformers_non_numeric() + [ SimpleImputer(strategy=SimpleImputer.Strategy.mode()), + KNearestNeighborsImputer(neighbor_count=3, value_to_replace=None), ] )