Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions modopt/opt/algorithms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from inspect import getmro

import numpy as np
from progressbar import ProgressBar
from tqdm.auto import tqdm

from modopt.base import backend
from modopt.base.observable import MetricObserver, Observable
Expand All @@ -15,7 +15,7 @@ class SetUp(Observable):
r"""Algorithm Set-Up.

This class contains methods for checking the set-up of an optimisation
algotithm and produces warnings if they do not comply.
algorithm and produces warnings if they do not comply.

Parameters
----------
Expand All @@ -38,7 +38,6 @@ class SetUp(Observable):
--------
modopt.base.observable.Observable : parent class
modopt.base.observable.MetricObserver : definition of metrics

"""

def __init__(
Expand Down Expand Up @@ -240,9 +239,8 @@ def _iterations(self, max_iter, progbar=None):
----------
max_iter : int
Maximum number of iterations
progbar : progressbar.bar.ProgressBar
Progress bar (default is ``None``)

progbar: tqdm.tqdm
Progress bar handle (default is ``None``)
"""
for idx in range(max_iter):
self.idx = idx
Expand All @@ -268,10 +266,10 @@ def _iterations(self, max_iter, progbar=None):
print(' - Converged!')
break

if not isinstance(progbar, type(None)):
progbar.update(idx)
if progbar:
progbar.update()

def _run_alg(self, max_iter):
def _run_alg(self, max_iter, progbar=None):
"""Run algorithm.

Run the update step of a given algorithm up to the maximum number of
Expand All @@ -281,17 +279,34 @@ def _run_alg(self, max_iter):
----------
max_iter : int
Maximum number of iterations
progbar: tqdm.tqdm
Progress bar handle (default is ``None``)

See Also
--------
progressbar.bar.ProgressBar
tqdm.tqdm

"""
if self.progress:
with ProgressBar(
redirect_stdout=True,
max_value=max_iter,
) as progbar:
self._iterations(max_iter, progbar=progbar)
if self.progress and progbar is None:
with tqdm(total=max_iter) as pb:
self._iterations(max_iter, progbar=pb)
elif progbar:
self._iterations(max_iter, progbar=progbar)
else:
self._iterations(max_iter)

def _update(self):
raise NotImplementedError

def get_notify_observers_kwargs(self):
"""Notify Observers.

Return the mapping between the metrics call and the iterated
variables.

Raises
------
NotImplementedError
This method should be overriden by subclasses.
"""
raise NotImplementedError
21 changes: 12 additions & 9 deletions modopt/opt/algorithms/forward_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def _update(self):
or self._cost_func.get_cost(self._x_new)
)

def iterate(self, max_iter=150):
def iterate(self, max_iter=150, progbar=None):
"""Iterate.

This method calls update until either the convergence criteria is met
Expand All @@ -477,9 +477,10 @@ def iterate(self, max_iter=150):
----------
max_iter : int, optional
Maximum number of iterations (default is ``150``)

progbar: tqdm.tqdm
Progress bar handle (default is ``None``)
"""
self._run_alg(max_iter)
self._run_alg(max_iter, progbar)

# retrieve metrics results
self.retrieve_outputs()
Expand Down Expand Up @@ -750,7 +751,7 @@ def _update(self):
if self._cost_func:
self.converge = self._cost_func.get_cost(self._x_new)

def iterate(self, max_iter=150):
def iterate(self, max_iter=150, progbar=None):
"""Iterate.

This method calls update until either convergence criteria is met or
Expand All @@ -760,9 +761,10 @@ def iterate(self, max_iter=150):
----------
max_iter : int, optional
Maximum number of iterations (default is ``150``)

progbar: tqdm.tqdm
Progress bar handle (default is ``None``)
"""
self._run_alg(max_iter)
self._run_alg(max_iter, progbar)

# retrieve metrics results
self.retrieve_outputs()
Expand Down Expand Up @@ -995,7 +997,7 @@ def _update(self):
or self._cost_func.get_cost(self._x_new)
)

def iterate(self, max_iter=150):
def iterate(self, max_iter=150, progbar=None):
"""Iterate.

This method calls update until either convergence criteria is met or
Expand All @@ -1005,9 +1007,10 @@ def iterate(self, max_iter=150):
----------
max_iter : int, optional
Maximum number of iterations (default is ``150``)

progbar: tqdm.tqdm
Progress bar handle (default is ``None``)
"""
self._run_alg(max_iter)
self._run_alg(max_iter, progbar)

# retrieve metrics results
self.retrieve_outputs()
Expand Down
11 changes: 7 additions & 4 deletions modopt/opt/algorithms/primal_dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _update(self):
or self._cost_func.get_cost(self._x_new, self._y_new)
)

def iterate(self, max_iter=150, n_rewightings=1):
def iterate(self, max_iter=150, n_rewightings=1, progbar=None):
"""Iterate.

This method calls update until either convergence criteria is met or
Expand All @@ -237,14 +237,17 @@ def iterate(self, max_iter=150, n_rewightings=1):
Maximum number of iterations (default is ``150``)
n_rewightings : int, optional
Number of reweightings to perform (default is ``1``)

progbar: tqdm.tqdm
Progress bar handle (default is ``None``)
"""
self._run_alg(max_iter)
self._run_alg(max_iter, progbar)

if not isinstance(self._reweight, type(None)):
for _ in range(n_rewightings):
self._reweight.reweight(self._linear.op(self._x_new))
self._run_alg(max_iter)
if progbar:
progbar.reset(total=max_iter)
self._run_alg(max_iter, progbar)

# retrieve metrics results
self.retrieve_outputs()
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
importlib_metadata>=3.7.0
numpy>=1.19.5
scipy>=1.5.4
progressbar2>=3.53.1
tqdm>=4.64.0