Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/source/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,22 @@ @article{zou2005
journal = {Journal of the Royal Statistical Society Series B},
doi = {10.1111/j.1467-9868.2005.00527.x}
}

@article{ruder2017,
title = {An overview of gradient descent optimization algorithms},
url = {http://arxiv.org/abs/1609.04747},
journaltitle = {{arXiv}:1609.04747 [cs]},
author = {Ruder, Sebastian},
urldate = {2021-12-09},
date = {2017-06-15},
langid = {english},
eprinttype = {arxiv},
eprint = {1609.04747},
}

@article{defazio2014,
	title = {{SAGA}: A Fast Incremental Gradient Method With Support for Non-Strongly Convex Composite Objectives},
	url = {http://arxiv.org/abs/1407.0202},
	pages = {15},
	author = {Defazio, Aaron and Bach, Francis and Lacoste-Julien, Simon},
	date = {2014},
	langid = {english},
	eprinttype = {arxiv},
	eprint = {1407.0202},
}
56 changes: 56 additions & 0 deletions modopt/opt/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
r"""OPTIMISATION ALGORITHMS.

This module contains class implementations of various optimisation algorithms.

:Authors: Samuel Farrens <samuel.farrens@cea.fr>,
Zaccharie Ramzi <zaccharie.ramzi@cea.fr>

:Notes:

Input classes must have the following properties:

    * **Gradient Operators**

    Must have the following methods:

        * ``get_grad()`` - calculate the gradient

    Must have the following variables:

        * ``grad`` - the gradient

    * **Linear Operators**

    Must have the following methods:

        * ``op()`` - operator
        * ``adj_op()`` - adjoint operator

    * **Proximity Operators**

    Must have the following methods:

        * ``op()`` - operator

The following notation is used to implement the algorithms:

    * x_old is used in place of :math:`x_{n}`.
    * x_new is used in place of :math:`x_{n+1}`.
    * x_prox is used in place of :math:`\tilde{x}_{n+1}`.
    * x_temp is used for intermediate operations.

"""

from modopt.opt.algorithms.base import SetUp
from modopt.opt.algorithms.forward_backward import (FISTA, POGM,
                                                    ForwardBackward,
                                                    GenForwardBackward)
from modopt.opt.algorithms.gradient_descent import (AdaGenericGradOpt,
                                                    ADAMGradOpt,
                                                    GenericGradOpt,
                                                    MomentumGradOpt,
                                                    RMSpropGradOpt,
                                                    SAGAOptGradOpt,
                                                    VanillaGenericGradOpt)
from modopt.opt.algorithms.primal_dual import Condat
292 changes: 292 additions & 0 deletions modopt/opt/algorithms/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
# -*- coding: utf-8 -*-
"""Base SetUp for optimisation algorithms."""

from inspect import getmro

import numpy as np
from progressbar import ProgressBar

from modopt.base import backend
from modopt.base.observable import MetricObserver, Observable
from modopt.interface.errors import warn


class SetUp(Observable):
    r"""Algorithm Set-Up.

    This class contains methods for checking the set-up of an optimisation
    algorithm and produces warnings if they do not comply.

    Parameters
    ----------
    metric_call_period : int, optional
        Metric call period (default is ``5``)
    metrics : dict, optional
        Metrics to be used (default is ``\{\}``)
    verbose : bool, optional
        Option for verbose output (default is ``False``)
    progress : bool, optional
        Option to display progress bar (default is ``True``)
    step_size : int, optional
        Generic step size parameter to override default algorithm
        parameter name (`e.g.` `step_size` will override the value set for
        `beta_param` in `ForwardBackward`)
    compute_backend : str, optional
        Backend used for computation (default is ``'numpy'``)

    See Also
    --------
    modopt.base.observable.MetricObserver :
        Definition of Metrics.
    """

    def __init__(
        self,
        metric_call_period=5,
        metrics=None,
        verbose=False,
        progress=True,
        step_size=None,
        compute_backend='numpy',
        **dummy_kwargs,
    ):
        self.idx = 0
        self.converge = False
        self.verbose = verbose
        self.progress = progress
        self.metrics = metrics
        self.step_size = step_size
        # Names of the parent classes a valid operator is expected to
        # inherit from (see _check_operator)
        self._op_parents = (
            'GradParent',
            'ProximityParent',
            'LinearParent',
            'costObj',
        )

        self.metric_call_period = metric_call_period

        # Declaration of observers for metrics
        super().__init__(['cv_metrics'])

        for name, dic in self.metrics.items():
            observer = MetricObserver(
                name,
                dic['metric'],
                dic['mapping'],
                dic['cst_kwargs'],
                dic['early_stopping'],
            )
            self.add_observer('cv_metrics', observer)

        xp, compute_backend = backend.get_backend(compute_backend)
        self.xp = xp
        self.compute_backend = compute_backend

    @property
    def metrics(self):
        """Metrics."""
        return self._metrics

    @metrics.setter
    def metrics(self, metrics):

        if metrics is None:
            # Default to an empty metrics dictionary
            self._metrics = {}
        elif isinstance(metrics, dict):
            self._metrics = metrics
        else:
            raise TypeError(
                'Metrics must be a dictionary, not {0}.'.format(type(metrics)),
            )

    def any_convergence_flag(self):
        """Check convergence flag.

        Return if any metrics values matched the convergence criteria.

        Returns
        -------
        bool
            True if any convergence criteria met

        """
        return any(
            obs.converge_flag for obs in self._observers['cv_metrics']
        )

    def copy_data(self, input_data):
        """Copy Data.

        Set directive for copying data, moving it to the compute backend
        first if necessary.

        Parameters
        ----------
        input_data : numpy.ndarray
            Input data

        Returns
        -------
        numpy.ndarray
            Copy of input data

        """
        return self.xp.copy(backend.change_backend(
            input_data,
            self.compute_backend,
        ))

    def _check_input_data(self, input_data):
        """Check input data type.

        This method checks if the input data is a numpy array or an array
        of the current compute backend.

        Parameters
        ----------
        input_data : numpy.ndarray
            Input data array

        Raises
        ------
        TypeError
            For invalid input type

        """
        if not isinstance(input_data, (self.xp.ndarray, np.ndarray)):
            raise TypeError(
                'Input data must be a numpy array or backend array',
            )

    def _check_param(self, param_val):
        """Check algorithm parameters.

        This method checks if the specified algorithm parameters are floats.

        Parameters
        ----------
        param_val : float
            Parameter value

        Raises
        ------
        TypeError
            For invalid input type

        """
        if not isinstance(param_val, float):
            raise TypeError('Algorithm parameter must be a float value.')

    def _check_param_update(self, param_update):
        """Check algorithm parameter update methods.

        This method checks if the specified parameter update method is
        callable (or ``None``, meaning no update is performed).

        Parameters
        ----------
        param_update : function
            Callable function

        Raises
        ------
        TypeError
            For invalid input type

        """
        if param_update is not None and not callable(param_update):
            raise TypeError(
                'Algorithm parameter update must be a callable function.',
            )

    def _check_operator(self, operator):
        """Check Set-Up.

        This method checks the algorithm operator against the expected
        parent classes and warns (without raising) on a mismatch.

        Parameters
        ----------
        operator : str
            Algorithm operator to check

        """
        if operator is not None:
            tree = [op_obj.__name__ for op_obj in getmro(operator.__class__)]

            if not any(parent in tree for parent in self._op_parents):
                message = '{0} does not inherit an operator parent.'
                warn(message.format(str(operator.__class__)))

    def _compute_metrics(self):
        """Compute metrics during iteration.

        This method creates the args necessary for metrics computation,
        then calls the observers to compute metrics.

        """
        kwargs = self.get_notify_observers_kwargs()
        self.notify_observers('cv_metrics', **kwargs)

    def _iterations(self, max_iter, progbar=None):
        """Iterate method.

        Iterate the update step of the given algorithm.

        Parameters
        ----------
        max_iter : int
            Maximum number of iterations
        progbar : progressbar.ProgressBar
            Progress bar (default is ``None``)

        """
        for idx in range(max_iter):
            self.idx = idx
            self._update()

            # Calling metrics every metric_call_period cycle
            # Also calculate at the end (max_iter or at convergence)
            # We do not call metrics if metrics is empty or metric call
            # period is None
            if self.metrics and self.metric_call_period is not None:

                # NOTE: no trailing comma here - a trailing comma would make
                # this a one-element tuple, which is always truthy and would
                # trigger metric computation on every iteration
                metric_conditions = (
                    self.idx % self.metric_call_period == 0
                    or self.idx == (max_iter - 1)
                    or self.converge
                )

                if metric_conditions:
                    self._compute_metrics()

            if self.converge:
                if self.verbose:
                    print(' - Converged!')
                break

            if progbar is not None:
                progbar.update(idx)

    def _run_alg(self, max_iter):
        """Run algorithm.

        Run the update step of a given algorithm up to the maximum number of
        iterations.

        Parameters
        ----------
        max_iter : int
            Maximum number of iterations

        """
        if self.progress:
            with ProgressBar(
                redirect_stdout=True,
                max_value=max_iter,
            ) as progbar:
                self._iterations(max_iter, progbar=progbar)
        else:
            self._iterations(max_iter)
Loading