diff --git a/docs/source/refs.bib b/docs/source/refs.bib index d8365e71..7782ca52 100644 --- a/docs/source/refs.bib +++ b/docs/source/refs.bib @@ -207,3 +207,15 @@ @article{zou2005 journal = {Journal of the Royal Statistical Society Series B}, doi = {10.1111/j.1467-9868.2005.00527.x} } + +@article{Goldstein2014, + author={Goldstein, Tom and O’Donoghue, Brendan and Setzer, Simon and Baraniuk, Richard}, + year={2014}, + month={Jan}, + pages={1588–1623}, + title={Fast Alternating Direction Optimization Methods}, + journal={SIAM Journal on Imaging Sciences}, + volume={7}, + ISSN={1936-4954}, + doi={10/gdwr49}, +} diff --git a/modopt/opt/algorithms/__init__.py b/modopt/opt/algorithms/__init__.py index e0ac2572..d4e7082b 100644 --- a/modopt/opt/algorithms/__init__.py +++ b/modopt/opt/algorithms/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -r"""OPTIMISATION ALGOTITHMS. +r"""OPTIMISATION ALGORITHMS. This module contains class implementations of various optimisation algoritms. @@ -57,3 +57,4 @@ SAGAOptGradOpt, VanillaGenericGradOpt) from modopt.opt.algorithms.primal_dual import Condat +from modopt.opt.algorithms.admm import ADMM, FastADMM diff --git a/modopt/opt/algorithms/admm.py b/modopt/opt/algorithms/admm.py new file mode 100644 index 00000000..b881b770 --- /dev/null +++ b/modopt/opt/algorithms/admm.py @@ -0,0 +1,337 @@ +"""ADMM Algorithms.""" +import numpy as np + +from modopt.base.backend import get_array_module +from modopt.opt.algorithms.base import SetUp +from modopt.opt.cost import CostParent + + +class ADMMcostObj(CostParent): + r"""Cost Object for the ADMM problem class. + + Parameters + ---------- + cost_funcs: 2-tuples of callable + f and g function. 
+ A : OperatorBase + First Operator + B : OperatorBase + Second Operator + b : numpy.ndarray + Observed data + **kwargs : dict + Extra parameters for cost operator configuration + + Notes + ----- + Compute :math:`f(u)+g(v) + \tau \| Au +Bv - b\|^2` + + See Also + -------- + CostParent: parent class + """ + + def __init__(self, cost_funcs, A, B, b, tau, **kwargs): + super().__init__(**kwargs) + self.cost_funcs = cost_funcs + self.A = A + self.B = B + self.b = b + self.tau = tau + + def _calc_cost(self, u, v, **kwargs): + """Calculate the cost. + + This method calculates the cost from each of the input operators. + + Parameters + ---------- + u: numpy.ndarray + First primal variable of ADMM + v: numpy.ndarray + Second primal variable of ADMM + + Returns + ------- + float + Cost value + + """ + xp = get_array_module(u) + cost = self.cost_funcs[0](u) + cost += self.cost_funcs[1](v) + cost += self.tau * xp.linalg.norm(self.A.op(u) + self.B.op(v) - self.b) + return cost + + +class ADMM(SetUp): + r"""ADMM Optimisation Algorithm. + + This class implements the ADMM algorithm described in :cite:`Goldstein2014` (Algorithm 1). + + Parameters + ---------- + u: numpy.ndarray + Initial value for first primal variable of ADMM + v: numpy.ndarray + Initial value for second primal variable of ADMM + mu: numpy.ndarray + Initial value for lagrangian multiplier. + A : modopt.opt.linear.LinearOperator + Linear operator for u + B: modopt.opt.linear.LinearOperator + Linear operator for v + b : numpy.ndarray + Constraint vector + optimizers: tuple + 2-tuple of callable, that are the optimizers for the u and v. + Each callable should access an init and obs argument and returns an estimate for: + .. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2 + .. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2 + cost_funcs: tuple + 2-tuple of callable, that compute values of H and G. + tau: float, default=1 + Coupling parameter for ADMM. 
+ + Notes + ----- + The algorithm solves the problem: + + .. math:: u, v = \arg\min H(u) + G(v) + \frac\tau2 \|Au + Bv - b \|_2^2 + + with the following augmented lagrangian: + + .. math :: \mathcal{L}_{\tau}(u,v, \lambda) = H(u) + G(v) + +\langle\lambda |Au + Bv -b \rangle + \frac\tau2 \| Au + Bv -b \|^2 + + To allow easy iterative solving, the change of variable + :math:`\mu=\lambda/\tau` is used. Hence, the lagrangian of interest is: + + .. math :: \tilde{\mathcal{L}}_{\tau}(u,v, \mu) = H(u) + G(v) + + \frac\tau2 \left(\|\mu + Au +Bv - b\|^2 - \|\mu\|^2\right) + + See Also + -------- + SetUp: parent class + """ + + def __init__( + self, + u, + v, + mu, + A, + B, + b, + optimizers, + tau=1, + cost_funcs=None, + **kwargs, + ): + super().__init__(**kwargs) + self.A = A + self.B = B + self.b = b + self._opti_H = optimizers[0] + self._opti_G = optimizers[1] + self._tau = tau + if cost_funcs is not None: + self._cost_func = ADMMcostObj(cost_funcs, A, B, b, tau) + else: + self._cost_func = None + + # init iteration variables. + self._u_old = self.xp.copy(u) + self._u_new = self.xp.copy(u) + self._v_old = self.xp.copy(v) + self._v_new = self.xp.copy(v) + self._mu_new = self.xp.copy(mu) + self._mu_old = self.xp.copy(mu) + + def _update(self): + self._u_new = self._opti_H( + init=self._u_old, + obs=self.B.op(self._v_old) + self._mu_old - self.b, + ) + tmp = self.A.op(self._u_new) + self._v_new = self._opti_G( + init=self._v_old, + obs=tmp + self._mu_old - self.b, + ) + + self._mu_new = self._mu_old + (tmp + self.B.op(self._v_new) - self.b) + + # update cycle + self._u_old = self.xp.copy(self._u_new) + self._v_old = self.xp.copy(self._v_new) + self._mu_old = self.xp.copy(self._mu_new) + + # Test cost function for convergence. + if self._cost_func: + self.converge = self.any_convergence_flag() + self.converge |= self._cost_func.get_cost(self._u_new, self._v_new) + + def iterate(self, max_iter=150): + """Iterate. 
+ + This method calls update until either convergence criteria is met or + the maximum number of iterations is reached. + + Parameters + ---------- + max_iter : int, optional + Maximum number of iterations (default is ``150``) + """ + self._run_alg(max_iter) + + # retrieve metrics results + self.retrieve_outputs() + # rename outputs as attributes + self.u_final = self._u_new + self.x_final = self.u_final # for backward compatibility + self.v_final = self._v_new + + def get_notify_observers_kwargs(self): + """Notify observers. + + Return the mapping between the metrics call and the iterated + variables. + + Returns + ------- + dict + The mapping between the iterated variables + """ + return { + 'x_new': self._u_new, + 'v_new': self._v_new, + 'idx': self.idx, + } + + def retrieve_outputs(self): + """Retrieve outputs. + + Declare the outputs of the algorithms as attributes: u_final, + v_final, metrics. + """ + metrics = {} + for obs in self._observers['cv_metrics']: + metrics[obs.name] = obs.retrieve_metrics() + self.metrics = metrics + + +class FastADMM(ADMM): + r"""Fast ADMM Optimisation Algorithm. + + This class implements the fast ADMM algorithm + (Algorithm 8 from :cite:`Goldstein2014`) + + Parameters + ---------- + u: numpy.ndarray + Initial value for first primal variable of ADMM + v: numpy.ndarray + Initial value for second primal variable of ADMM + mu: numpy.ndarray + Initial value for lagrangian multiplier. + A : modopt.opt.linear.LinearOperator + Linear operator for u + B: modopt.opt.linear.LinearOperator + Linear operator for v + b : numpy.ndarray + Constraint vector + optimizers: tuple + 2-tuple of callable, that are the optimizers for the u and v. + Each callable should access an init and obs argument and returns an estimate for: + .. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2 + .. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2 + cost_funcs: tuple + 2-tuple of callable, that compute values of H and G. 
+ tau: float, default=1 + Coupling parameter for ADMM. + eta: float, default=0.999 + Convergence parameter for ADMM. + alpha: float, default=1. + Initial value for the FISTA-like acceleration parameter. + + Notes + ----- + This is an accelerated version of the ADMM algorithm. The convergence hypotheses are stronger than for the ADMM algorithm. + + See Also + -------- + ADMM: parent class + """ + + def __init__( + self, + u, + v, + mu, + A, + B, + b, + optimizers, + cost_funcs=None, + alpha=1, + eta=0.999, + tau=1, + **kwargs, + ): + super().__init__( + u=u, + v=v, + mu=mu, + A=A, + B=B, + b=b, + optimizers=optimizers, + cost_funcs=cost_funcs, + **kwargs, + ) + self._c_old = np.inf + self._c_new = 0 + self._eta = eta + self._alpha_old = alpha + self._alpha_new = alpha + self._v_hat = self.xp.copy(self._v_new) + self._mu_hat = self.xp.copy(self._mu_new) + + def _update(self): + # Classical ADMM steps + self._u_new = self._opti_H( + init=self._u_old, + obs=self.B.op(self._v_hat) + self._mu_hat - self.b, + ) + tmp = self.A.op(self._u_new) + self._v_new = self._opti_G( + init=self._v_hat, + obs=tmp + self._mu_hat - self.b, + ) + + self._mu_new = self._mu_hat + (tmp + self.B.op(self._v_new) - self.b) + + # restarting condition + self._c_new = self.xp.linalg.norm(self._mu_new - self._mu_hat) + self._c_new += self._tau * self.xp.linalg.norm( + self.B.op(self._v_new - self._v_hat), + ) + if self._c_new < self._eta * self._c_old: + self._alpha_new = (1 + np.sqrt(1 + 4 * self._alpha_old**2)) / 2 + beta = (self._alpha_old - 1) / self._alpha_new + self._v_hat = self._v_new + (self._v_new - self._v_old) * beta + self._mu_hat = self._mu_new + (self._mu_new - self._mu_old) * beta + else: + # reboot to old iteration + self._alpha_new = 1 + self._v_hat = self._v_old + self._mu_hat = self._mu_old + self._c_new = self._c_old / self._eta + + self.xp.copyto(self._u_old, self._u_new) + self.xp.copyto(self._v_old, self._v_new) + self.xp.copyto(self._mu_old, self._mu_new) + # roll the restart/acceleration state forward (Goldstein2014, Alg. 8) + self._c_old = self._c_new + self._alpha_old = self._alpha_new + # Test cost function 
 for convergence. + if self._cost_func: + self.converge = self.any_convergence_flag() + self.converge |= self._cost_func.get_cost(self._u_new, self._v_new) diff --git a/modopt/opt/cost.py b/modopt/opt/cost.py index 3cdfcc50..688a3959 100644 --- a/modopt/opt/cost.py +++ b/modopt/opt/cost.py @@ -6,6 +6,8 @@ """ +import abc + import numpy as np from modopt.base.backend import get_array_module @@ -13,8 +15,8 @@ from modopt.plot.cost_plot import plotCost -class costObj(object): - """Generic cost function object. +class CostParent(abc.ABC): + """Abstract cost function object. This class updates the cost according to the input operator classes and tests for convergence. @@ -40,7 +42,8 @@ class costObj(object): Notes ----- - The costFunc class must contain a method called ``cost``. + All child classes should implement a ``_calc_cost`` method (returning + a float) or a ``get_cost`` for more complex behavior on convergence test. Examples -------- @@ -71,7 +74,6 @@ class costObj(object): def __init__( self, - operators, initial_cost=1e6, tolerance=1e-4, cost_interval=1, @@ -80,9 +82,6 @@ def __init__( plot_output=None, ): - self._operators = operators - if not isinstance(operators, type(None)): - self._check_operators() self.cost = initial_cost self._cost_list = [] self._cost_interval = cost_interval @@ -93,30 +92,6 @@ def __init__( self._plot_output = plot_output self._verbose = verbose - def _check_operators(self): - """Check operators. - - This method checks if the input operators have a ``cost`` method. 
- - Raises - ------ - TypeError - For invalid operators type - ValueError - For operators without ``cost`` method - - """ - if not isinstance(self._operators, (list, tuple, np.ndarray)): - message = ( - 'Input operators must be provided as a list, not {0}' - ) - raise TypeError(message.format(type(self._operators))) - - for op in self._operators: - if not hasattr(op, 'cost'): - raise ValueError('Operators must contain "cost" method.') - op.cost = check_callable(op.cost) - def _check_cost(self): """Check cost function. @@ -167,6 +142,7 @@ def _check_cost(self): return False + @abc.abstractmethod def _calc_cost(self, *args, **kwargs): """Calculate the cost. @@ -178,14 +154,7 @@ def _calc_cost(self, *args, **kwargs): Positional arguments **kwargs : dict Keyword arguments - - Returns - ------- - float - Cost value - """ - return np.sum([op.cost(*args, **kwargs) for op in self._operators]) def get_cost(self, *args, **kwargs): """Get cost function. @@ -241,3 +210,110 @@ def plot_cost(self): # pragma: no cover """ plotCost(self._cost_list, self._plot_output) + + +class costObj(CostParent): + """Abstract cost function object. + + This class updates the cost according to the input operator classes and + tests for convergence. + + Parameters + ---------- + opertors : list, tuple or numpy.ndarray + List of operators classes containing ``cost`` method + initial_cost : float, optional + Initial value of the cost (default is ``1e6``) + tolerance : float, optional + Tolerance threshold for convergence (default is ``1e-4``) + cost_interval : int, optional + Iteration interval to calculate cost (default is ``1``). + If ``cost_interval`` is ``None`` the cost is never calculated, + thereby saving on computation time. 
+ test_range : int, optional + Number of cost values to be used in test (default is ``4``) + verbose : bool, optional + Option for verbose output (default is ``True``) + plot_output : str, optional + Output file name for cost function plot + + Examples + -------- + >>> from modopt.opt.cost import * + >>> class dummy(object): + ... def cost(self, x): + ... return x ** 2 + ... + ... + >>> inst = costObj([dummy(), dummy()]) + >>> inst.get_cost(2) + - ITERATION: 1 + - COST: 8 + + False + >>> inst.get_cost(2) + - ITERATION: 2 + - COST: 8 + + False + >>> inst.get_cost(2) + - ITERATION: 3 + - COST: 8 + + False + """ + + def __init__( + self, + operators, + **kwargs, + ): + super().__init__(**kwargs) + + self._operators = operators + if not isinstance(operators, type(None)): + self._check_operators() + + def _check_operators(self): + """Check operators. + + This method checks if the input operators have a ``cost`` method. + + Raises + ------ + TypeError + For invalid operators type + ValueError + For operators without ``cost`` method + + """ + if not isinstance(self._operators, (list, tuple, np.ndarray)): + message = ( + 'Input operators must be provided as a list, not {0}' + ) + raise TypeError(message.format(type(self._operators))) + + for op in self._operators: + if not hasattr(op, 'cost'): + raise ValueError('Operators must contain "cost" method.') + op.cost = check_callable(op.cost) + + def _calc_cost(self, *args, **kwargs): + """Calculate the cost. + + This method calculates the cost from each of the input operators. 
+ + Parameters + ---------- + *args : tuple + Positional arguments + **kwargs : dict + Keyword arguments + + Returns + ------- + float + Cost value + + """ + return np.sum([op.cost(*args, **kwargs) for op in self._operators]) diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py index 73091acd..5671b8e3 100644 --- a/modopt/tests/test_algorithms.py +++ b/modopt/tests/test_algorithms.py @@ -80,7 +80,15 @@ def build_kwargs(kwargs, use_metrics): @parametrize(use_metrics=[True, False]) class AlgoCases: - """Cases for algorithms.""" + """Cases for algorithms. + + Most of the test solves the trivial problem + + .. math:: + \\min_x \\frac{1}{2} \\| y - x \\|_2^2 \\quad\\text{s.t.} x \\geq 0 + + More complex and concrete usecases are shown in examples. + """ data1 = np.arange(9).reshape(3, 3).astype(float) data2 = data1 + np.random.randn(*data1.shape) * 1e-6 @@ -103,7 +111,8 @@ class AlgoCases: ] ) def case_forward_backward(self, kwargs, idty, use_metrics): - """Forward Backward case.""" + """Forward Backward case. 
+ """ update_kwargs = build_kwargs(kwargs, use_metrics) algo = algorithms.ForwardBackward( self.data1, @@ -233,7 +242,28 @@ def case_grad(self, GradDescent, use_metrics, idty): ) algo.iterate() return algo, update_kwargs + @parametrize(admm=[algorithms.ADMM,algorithms.FastADMM]) + def case_admm(self, admm, use_metrics, idty): + """ADMM setup.""" + def optim1(init, obs): + return obs + + def optim2(init, obs): + return obs + update_kwargs = build_kwargs({}, use_metrics) + algo = admm( + u=self.data1, + v=self.data1, + mu=np.zeros_like(self.data1), + A=linear.Identity(), + B=linear.Identity(), + b=self.data1, + optimizers=(optim1, optim2), + **update_kwargs, + ) + algo.iterate() + return algo, update_kwargs @parametrize_with_cases("algo, kwargs", cases=AlgoCases) def test_algo(algo, kwargs): diff --git a/setup.cfg b/setup.cfg index 8d8e821b..100adb40 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,8 @@ per-file-ignores = modopt/opt/algorithms/__init__.py: F401,F403,WPS318, WPS319, WPS412, WPS410 #Todo: x is a too short name. modopt/opt/algorithms/forward_backward.py: WPS111 + #Todo: u,v , A is a too short name. + modopt/opt/algorithms/admm.py: WPS111, N803 #Todo: Check need for del statement modopt/opt/algorithms/primal_dual.py: WPS111, WPS420 #multiline parameters bug with tuples