Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions src/safeds/ml/classical/regression/_elastic_net_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import warnings
from typing import TYPE_CHECKING
from warnings import warn

from sklearn.linear_model import ElasticNet as sk_ElasticNet

Expand All @@ -14,24 +15,59 @@


class ElasticNetRegression(Regressor):
"""Elastic net regression."""
"""Elastic net regression.

Parameters
----------
alpha : float
Controls the regularization of the model. The higher the value, the more regularized it becomes.

lasso_ratio: float
Number between 0 and 1 that controls the ratio between Lasso- and Ridge regularization.
lasso_ratio=0 is essentially RidgeRegression
lasso_ratio=1 is essentially LassoRegression

Raises
------
ValueError
If alpha is negative.
"""

def __init__(self, alpha: float = 1.0, lasso_ratio: float = 0.5) -> None:
if alpha < 0:
raise ValueError("alpha must be non-negative")
if alpha == 0:
warn(
(
"Setting alpha to zero makes this model equivalent to LinearRegression. You should use "
"LinearRegression instead for better numerical stability."
),
UserWarning,
stacklevel=2,
)

self._alpha = alpha

def __init__(self, lasso_ratio: float = 0.5) -> None:
if lasso_ratio < 0 or lasso_ratio > 1:
raise ValueError("lasso_ratio must be between 0 and 1.")
elif lasso_ratio == 0:
warnings.warn(
"ElasticNetRegression with lasso_ratio = 0 is essentially RidgeRegression."
" Use RidgeRegression instead for better numerical stability.",
stacklevel=1,
(
"ElasticNetRegression with lasso_ratio = 0 is essentially RidgeRegression."
" Use RidgeRegression instead for better numerical stability."
),
stacklevel=2,
)
elif lasso_ratio == 1:
warnings.warn(
"ElasticNetRegression with lasso_ratio = 0 is essentially LassoRegression."
" Use LassoRegression instead for better numerical stability.",
stacklevel=1,
(
"ElasticNetRegression with lasso_ratio = 0 is essentially LassoRegression."
" Use LassoRegression instead for better numerical stability."
),
stacklevel=2,
)
self.lasso_ratio = lasso_ratio

self._wrapped_regressor: sk_ElasticNet | None = None
self._feature_names: list[str] | None = None
self._target_name: str | None = None
Expand All @@ -57,10 +93,10 @@ def fit(self, training_set: TaggedTable) -> ElasticNetRegression:
LearningError
If the training data contains invalid values or if the training failed.
"""
wrapped_regressor = sk_ElasticNet(l1_ratio=self.lasso_ratio)
wrapped_regressor = sk_ElasticNet(alpha=self._alpha, l1_ratio=self.lasso_ratio)
fit(wrapped_regressor, training_set)

result = ElasticNetRegression(self.lasso_ratio)
result = ElasticNetRegression(alpha=self._alpha, lasso_ratio=self.lasso_ratio)
result._wrapped_regressor = wrapped_regressor
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
Original file line number Diff line number Diff line change
@@ -1,39 +1,68 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.ml.classical.regression._elastic_net_regression import ElasticNetRegression
from safeds.ml.classical.regression import ElasticNetRegression


def test_lasso_ratio_valid() -> None:
def test_should_throw_value_error_alpha() -> None:
with pytest.raises(ValueError, match="alpha must be non-negative"):
ElasticNetRegression(alpha=-1.0)


def test_should_throw_warning_alpha() -> None:
with pytest.warns(
UserWarning,
match=(
"Setting alpha to zero makes this model equivalent to LinearRegression. You "
"should use LinearRegression instead for better numerical stability."
),
):
ElasticNetRegression(alpha=0.0)


def test_should_give_alpha_to_sklearn() -> None:
training_set = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
tagged_training_set = training_set.tag_columns(target_name="col1", feature_names=["col2"])

elastic_net_regression = ElasticNetRegression(alpha=1.0).fit(tagged_training_set)
assert elastic_net_regression._wrapped_regressor is not None
assert elastic_net_regression._wrapped_regressor.alpha == elastic_net_regression._alpha


def test_should_give_lasso_ratio_to_sklearn() -> None:
training_set = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
tagged_training_set = training_set.tag_columns(target_name="col1", feature_names=["col2"])
lasso_ratio = 0.3

elastic_net_regression = ElasticNetRegression(lasso_ratio).fit(tagged_training_set)
elastic_net_regression = ElasticNetRegression(lasso_ratio=lasso_ratio).fit(tagged_training_set)
assert elastic_net_regression._wrapped_regressor is not None
assert elastic_net_regression._wrapped_regressor.l1_ratio == lasso_ratio


def test_lasso_ratio_invalid() -> None:
def test_should_throw_value_error_lasso_ratio() -> None:
with pytest.raises(ValueError, match="lasso_ratio must be between 0 and 1."):
ElasticNetRegression(-1)
ElasticNetRegression(lasso_ratio=-1.0)


def test_lasso_ratio_zero() -> None:
def test_should_throw_warning_lasso_ratio_zero() -> None:
with pytest.warns(
UserWarning,
match="ElasticNetRegression with lasso_ratio = 0 is essentially RidgeRegression."
" Use RidgeRegression instead for better numerical stability.",
match=(
"ElasticNetRegression with lasso_ratio = 0 is essentially RidgeRegression."
" Use RidgeRegression instead for better numerical stability."
),
):
ElasticNetRegression(0)
ElasticNetRegression(lasso_ratio=0)


def test_lasso_ratio_one() -> None:
def test_should_throw_warning_lasso_ratio_one() -> None:
with pytest.warns(
UserWarning,
match="ElasticNetRegression with lasso_ratio = 0 is essentially LassoRegression."
" Use LassoRegression instead for better numerical stability.",
match=(
"ElasticNetRegression with lasso_ratio = 0 is essentially LassoRegression."
" Use LassoRegression instead for better numerical stability."
),
):
ElasticNetRegression(1)
ElasticNetRegression(lasso_ratio=1)


# (Default parameter is tested in `test_regressor.py`.)