From 308e1e3004ae202f4ba4460c220e8639edeeb1a4 Mon Sep 17 00:00:00 2001 From: Moritz Date: Wed, 12 Nov 2025 15:33:22 +0100 Subject: [PATCH 1/2] add break condition choice to adam minimize --- src/squlearn/optimizers/adam.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/squlearn/optimizers/adam.py b/src/squlearn/optimizers/adam.py index 9f718367..110f4f7f 100644 --- a/src/squlearn/optimizers/adam.py +++ b/src/squlearn/optimizers/adam.py @@ -12,6 +12,8 @@ class Adam(OptimizerBase, SGDMixin): Possible options that can be set in the options dictionary are: * **tol** (float): Tolerance for the termination of the optimization (default: 1e-6) + * **condition** (str): Break when parameter update is below tol ('param_update') or when + the function value ('func_value') goes below tol (default: 'param_update') * **lr** (float, list, np.ndarray, callable): Learning rate. If float, the learning rate is constant. If list or np.ndarray, the learning rate is taken from the list or array. If callable, the learning rate is taken from the function. (default: 0.05) @@ -37,6 +39,9 @@ def __init__(self, options: dict = None, callback=default_callback) -> None: # options = {} self.tol = options.get("tol", 1e-6) + self.condition = options.get("condition", "param_update") + if self.condition not in ["param_update", "func_value"]: + raise ValueError("Condition must be 'param_update' or 'func_value'.") self.lr = options.get("lr", 0.05) self.beta_1 = options.get("beta_1", 0.9) self.beta_2 = options.get("beta_2", 0.99) @@ -141,8 +146,12 @@ def minimize( self.callback(self.iteration, self.x, gradient, fval) # check termination - if np.linalg.norm(self.x - x_updated) < self.tol: - break + if self.condition == "func_value": # func_value + if fval is not None and fval < self.tol: + break + else: # param_update + if np.linalg.norm(self.x - x_updated) < self.tol: + break self.x = x_updated From 1f84f032ec39acd37a4b5d5ed392ad0ce4aa926b Mon Sep 17 00:00:00 2001 From: Moritz <44642314+MoritzWillmann@users.noreply.github.com> Date: Tue, 18 Nov 2025 11:06:48 +0100 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Florian Wieland <114916947+ProfessorNova@users.noreply.github.com> --- src/squlearn/optimizers/adam.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/squlearn/optimizers/adam.py b/src/squlearn/optimizers/adam.py index 110f4f7f..1ef5d689 100644 --- a/src/squlearn/optimizers/adam.py +++ b/src/squlearn/optimizers/adam.py @@ -12,7 +12,7 @@ class Adam(OptimizerBase, SGDMixin): Possible options that can be set in the options dictionary are: * **tol** (float): Tolerance for the termination of the optimization (default: 1e-6) - * **condition** (str): Break when parameter update is below tol ('param_update') or when + * **break_condition** (str): Break when parameter update is below tol ('param_update') or when the function value ('func_value') goes below tol (default: 'param_update') * **lr** (float, list, np.ndarray, callable): Learning rate. If float, the learning rate is constant. If list or np.ndarray, the learning rate is taken from the list or array. @@ -39,9 +39,9 @@ def __init__(self, options: dict = None, callback=default_callback) -> None: # options = {} self.tol = options.get("tol", 1e-6) - self.condition = options.get("condition", "param_update") - if self.condition not in ["param_update", "func_value"]: - raise ValueError("Condition must be 'param_update' or 'func_value'.") + self.break_condition = options.get("break_condition", "param_update") + if self.break_condition not in ["param_update", "func_value"]: + raise ValueError("Break condition must be 'param_update' or 'func_value'.") self.lr = options.get("lr", 0.05) self.beta_1 = options.get("beta_1", 0.9) self.beta_2 = options.get("beta_2", 0.99) @@ -146,7 +146,7 @@ def minimize( self.callback(self.iteration, self.x, gradient, fval) # check termination - if self.condition == "func_value": # func_value + if self.break_condition == "func_value": # func_value if fval is not None and fval < self.tol: break else: # param_update