diff --git a/src/squlearn/optimizers/adam.py b/src/squlearn/optimizers/adam.py
index 9f718367..1ef5d689 100644
--- a/src/squlearn/optimizers/adam.py
+++ b/src/squlearn/optimizers/adam.py
@@ -12,6 +12,8 @@ class Adam(OptimizerBase, SGDMixin):
     Possible options that can be set in the options dictionary are:
 
     * **tol** (float): Tolerance for the termination of the optimization (default: 1e-6)
+    * **break_condition** (str): Break when parameter update is below tol ('param_update') or when
+      the function value ('func_value') goes below tol (default: 'param_update')
     * **lr** (float, list, np.ndarray, callable): Learning rate. If float, the learning rate is
       constant. If list or np.ndarray, the learning rate is taken from the list or array. If
       callable, the learning rate is taken from the function. (default: 0.05)
@@ -37,6 +39,9 @@ def __init__(self, options: dict = None, callback=default_callback) -> None:
             options = {}
 
         self.tol = options.get("tol", 1e-6)
+        self.break_condition = options.get("break_condition", "param_update")
+        if self.break_condition not in ["param_update", "func_value"]:
+            raise ValueError("Break condition must be 'param_update' or 'func_value'.")
         self.lr = options.get("lr", 0.05)
         self.beta_1 = options.get("beta_1", 0.9)
         self.beta_2 = options.get("beta_2", 0.99)
@@ -141,8 +146,12 @@ def minimize(
             self.callback(self.iteration, self.x, gradient, fval)
 
             # check termination
-            if np.linalg.norm(self.x - x_updated) < self.tol:
-                break
+            if self.break_condition == "func_value":  # func_value
+                if fval is not None and fval < self.tol:
+                    break
+            else:  # param_update
+                if np.linalg.norm(self.x - x_updated) < self.tol:
+                    break
 
             self.x = x_updated
 
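Usage sketch (not part of the patch): a minimal example of selecting the new `break_condition` option. The import path `squlearn.optimizers.Adam` and the `options` dictionary follow the file touched above; the `minimize(fun, x0, grad)` call, the returned `.x` attribute, the test function, and the learning rate value are assumptions for illustration, not taken from this diff.

```python
import numpy as np

from squlearn.optimizers import Adam

# Simple convex test problem f(x) = ||x||^2 with minimum value 0, so a
# "func_value" break at tol=1e-6 is actually reachable (illustrative choice).
def fun(x):
    return float(np.sum(x**2))

def grad(x):
    return 2.0 * x

# New behaviour added by this patch: stop once the function value drops below tol.
adam_fval = Adam(options={"tol": 1e-6, "break_condition": "func_value", "lr": 0.1})

# Previous (and still default) behaviour: stop once the parameter update is below tol.
adam_step = Adam(options={"tol": 1e-6, "break_condition": "param_update", "lr": 0.1})

# Assumed minimize signature: (fun, x0, grad); arguments passed positionally here.
result = adam_fval.minimize(fun, np.array([1.0, -0.5]), grad)
print(result.x)
```

Note that the patch compares the raw function value against `tol` (`fval < self.tol`), so the `'func_value'` condition is mainly useful for objectives whose optimum is known to be at or near zero; in other cases the default `'param_update'` criterion remains the sensible choice.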