diff --git a/modopt/opt/algorithms/base.py b/modopt/opt/algorithms/base.py index 85c36306..b236013b 100644 --- a/modopt/opt/algorithms/base.py +++ b/modopt/opt/algorithms/base.py @@ -4,7 +4,7 @@ from inspect import getmro import numpy as np -from progressbar import ProgressBar +from tqdm.auto import tqdm from modopt.base import backend from modopt.base.observable import MetricObserver, Observable @@ -15,7 +15,7 @@ class SetUp(Observable): r"""Algorithm Set-Up. This class contains methods for checking the set-up of an optimisation - algotithm and produces warnings if they do not comply. + algorithm and produces warnings if they do not comply. Parameters ---------- @@ -38,7 +38,6 @@ class SetUp(Observable): -------- modopt.base.observable.Observable : parent class modopt.base.observable.MetricObserver : definition of metrics - """ def __init__( @@ -240,9 +239,8 @@ def _iterations(self, max_iter, progbar=None): ---------- max_iter : int Maximum number of iterations - progbar : progressbar.bar.ProgressBar - Progress bar (default is ``None``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ for idx in range(max_iter): self.idx = idx @@ -268,10 +266,10 @@ def _iterations(self, max_iter, progbar=None): print(' - Converged!') break - if not isinstance(progbar, type(None)): - progbar.update(idx) + if progbar: + progbar.update() - def _run_alg(self, max_iter): + def _run_alg(self, max_iter, progbar=None): """Run algorithm. Run the update step of a given algorithm up to the maximum number of @@ -281,17 +279,34 @@ def _run_alg(self, max_iter): ---------- max_iter : int Maximum number of iterations + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) See Also -------- - progressbar.bar.ProgressBar + tqdm.tqdm """ - if self.progress: - with ProgressBar( - redirect_stdout=True, - max_value=max_iter, - ) as progbar: - self._iterations(max_iter, progbar=progbar) + if self.progress and progbar is None: + with tqdm(total=max_iter) as pb: + self._iterations(max_iter, progbar=pb) + elif progbar: + self._iterations(max_iter, progbar=progbar) else: self._iterations(max_iter) + + def _update(self): + raise NotImplementedError + + def get_notify_observers_kwargs(self): + """Notify Observers. + + Return the mapping between the metrics call and the iterated + variables. + + Raises + ------ + NotImplementedError + This method should be overriden by subclasses. + """ + raise NotImplementedError diff --git a/modopt/opt/algorithms/forward_backward.py b/modopt/opt/algorithms/forward_backward.py index e18f66c3..6923a6af 100644 --- a/modopt/opt/algorithms/forward_backward.py +++ b/modopt/opt/algorithms/forward_backward.py @@ -467,7 +467,7 @@ def _update(self): or self._cost_func.get_cost(self._x_new) ) - def iterate(self, max_iter=150): + def iterate(self, max_iter=150, progbar=None): """Iterate. This method calls update until either the convergence criteria is met @@ -477,9 +477,10 @@ def iterate(self, max_iter=150): ---------- max_iter : int, optional Maximum number of iterations (default is ``150``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ - self._run_alg(max_iter) + self._run_alg(max_iter, progbar) # retrieve metrics results self.retrieve_outputs() @@ -750,7 +751,7 @@ def _update(self): if self._cost_func: self.converge = self._cost_func.get_cost(self._x_new) - def iterate(self, max_iter=150): + def iterate(self, max_iter=150, progbar=None): """Iterate. This method calls update until either convergence criteria is met or @@ -760,9 +761,10 @@ def iterate(self, max_iter=150): ---------- max_iter : int, optional Maximum number of iterations (default is ``150``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ - self._run_alg(max_iter) + self._run_alg(max_iter, progbar) # retrieve metrics results self.retrieve_outputs() @@ -995,7 +997,7 @@ def _update(self): or self._cost_func.get_cost(self._x_new) ) - def iterate(self, max_iter=150): + def iterate(self, max_iter=150, progbar=None): """Iterate. This method calls update until either convergence criteria is met or @@ -1005,9 +1007,10 @@ def iterate(self, max_iter=150): ---------- max_iter : int, optional Maximum number of iterations (default is ``150``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ - self._run_alg(max_iter) + self._run_alg(max_iter, progbar) # retrieve metrics results self.retrieve_outputs() diff --git a/modopt/opt/algorithms/primal_dual.py b/modopt/opt/algorithms/primal_dual.py index c8566969..d5bdd431 100644 --- a/modopt/opt/algorithms/primal_dual.py +++ b/modopt/opt/algorithms/primal_dual.py @@ -225,7 +225,7 @@ def _update(self): or self._cost_func.get_cost(self._x_new, self._y_new) ) - def iterate(self, max_iter=150, n_rewightings=1): + def iterate(self, max_iter=150, n_rewightings=1, progbar=None): """Iterate. This method calls update until either convergence criteria is met or @@ -237,14 +237,17 @@ def iterate(self, max_iter=150, n_rewightings=1): Maximum number of iterations (default is ``150``) n_rewightings : int, optional Number of reweightings to perform (default is ``1``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ - self._run_alg(max_iter) + self._run_alg(max_iter, progbar) if not isinstance(self._reweight, type(None)): for _ in range(n_rewightings): self._reweight.reweight(self._linear.op(self._x_new)) - self._run_alg(max_iter) + if progbar: + progbar.reset(total=max_iter) + self._run_alg(max_iter, progbar) # retrieve metrics results self.retrieve_outputs() diff --git a/requirements.txt b/requirements.txt index 63a404ba..1f44de13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ importlib_metadata>=3.7.0 numpy>=1.19.5 scipy>=1.5.4 -progressbar2>=3.53.1 +tqdm>=4.64.0