Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/source/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,14 @@ @article{zou2005
journal = {Journal of the Royal Statistical Society Series B},
doi = {10.1111/j.1467-9868.2005.00527.x}
}

@article{ryu2019,
  title = {Plug-and-Play Methods Provably Converge with Properly Trained Denoisers},
  url = {http://arxiv.org/abs/1905.05406},
  doi = {10.48550/arXiv.1905.05406},
  number = {arXiv:1905.05406},
  institution = {arXiv},
  author = {Ryu, Ernest K. and Liu, Jialin and Wang, Sicheng and Chen, Xiaohan and Wang, Zhangyang and Yin, Wotao},
  year = {2019},
  month = {May}
}
148 changes: 148 additions & 0 deletions modopt/opt/algorithms/pnp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
r"""Plug 'n play algorithms.

These algorithms solve

.. math:: \mathrm{arg}\min_{x\in \mathbb{R}^d} f(x) + g(x)

"""

from modopt.opt.algorithms import SetUp


class PnpADMM(SetUp):
    r"""Plug 'n Play ADMM.

    Implements (PNP-ADMM) of :cite:`ryu2019` to solve

    .. math:: \arg\min f(x) + g(x)

    Parameters
    ----------
    x_init: array
        Initial value of the iterate.
    proxf: Operator
        Data consistency function as a proximal operator.
    proxg: Operator
        Regularisation function (e.g. denoiser).
    alpha: float, default 1
        Data consistency parameter, analogous to a gradient step size.
    sigma: float, default 1
        Noise level parameter.
    """

    def __init__(
        self,
        x_init,
        proxf,
        proxg,
        alpha=1,
        sigma=1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Iteration variables: primal (x), auxiliary (y) and scaled
        # dual (u) terms, each seeded with a copy of the initial value.
        self._x_old = self.xp.copy(x_init)
        self._x_new = self.xp.copy(x_init)
        self._y_old = self.xp.copy(x_init)
        self._y_new = self.xp.copy(x_init)
        self._u_old = self.xp.copy(x_init)
        self._u_new = self.xp.copy(x_init)

        # Algorithm parameters.
        self._proxf = proxf
        self._proxg = proxg
        self._alpha = alpha
        self._sigma = sigma

    def _update(self):
        """Perform one PnP-ADMM iteration (scaled form).

        The iteration of :cite:`ryu2019` is::

            x_{k+1} = prox_{alpha*f}(y_k - u_k)
            y_{k+1} = prox_g(x_{k+1} + u_k)
            u_{k+1} = u_k + x_{k+1} - y_{k+1}
        """
        self._x_new = self._proxf.op(
            self._y_old - self._u_old,
            extra_factor=self._alpha,
        )
        # BUG FIX: the regularisation (denoising) step acts on
        # x_{k+1} + u_k; the previous code used x_{k+1} - u_k, which
        # does not match the scaled ADMM iteration.
        self._y_new = self._proxg.op(self._x_new + self._u_old)
        self._u_new = self._u_old + self._x_new - self._y_new

        # Roll the iteration variables for the next pass.
        self.xp.copyto(self._x_old, self._x_new)
        self.xp.copyto(self._y_old, self._y_new)
        self.xp.copyto(self._u_old, self._u_new)


class PnpFBS(SetUp):
    r"""Plug 'n Play Forward Backward Splitting.

    Implements (PNP-FBS) of :cite:`ryu2019`.

    Parameters
    ----------
    x_init: array
        Initial estimation.
    gradf: Operator
        Gradient of :math:`f`.
    proxg: Operator
        Proximal operator or plug-in replacement (e.g. denoiser) for
        :math:`g`.
    alpha: float, default 1
        Gradient descent step size.
    sigma: float, default 1
        Noise level.
    """

    def __init__(
        self,
        x_init,
        gradf,
        proxg,
        alpha=1,
        sigma=1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Iteration variables.
        self._x_old = self.xp.copy(x_init)
        self._x_new = self.xp.copy(x_init)

        # Algorithm parameters.
        self._gradf = gradf
        self._proxg = proxg
        self._alpha = alpha
        self._sigma = sigma

    def _update(self):
        """Perform one PnP-FBS iteration.

        Computes ``x_new = prox_g(x_old - alpha * grad_f(x_old))``.
        """
        self._gradf.get_grad(self._x_old)
        # Negate-then-add-in-place saves a (possibly expensive)
        # intermediate array allocation.
        self._x_new = (-self._alpha) * self._gradf.grad
        self._x_new += self._x_old
        self._x_new = self._proxg.op(self._x_new)

        # Update iteration.
        self.xp.copyto(self._x_old, self._x_new)

    def get_notify_observers_kwargs(self):
        """Notify observers.

        Return the mapping between the metrics call and the iterated
        variables.

        Returns
        -------
        dict
            Mapping holding the current iterate (``x_new``) and the
            iteration index (``idx``).
        """
        return {
            'x_new': self._x_new,
            'idx': self.idx,
        }

    def retrieve_outputs(self):
        """Retrieve outputs.

        Declare the outputs of the algorithm as attributes: only
        ``metrics`` is set here.

        NOTE(review): the original docstring also listed ``x_final``
        and ``y_final``, which this method does not set — presumably
        the parent class exposes them; confirm against ``SetUp``.
        """
        self.metrics = {
            obs.name: obs.retrieve_metrics()
            for obs in self._observers['cv_metrics']
        }
35 changes: 35 additions & 0 deletions modopt/opt/proximity.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,41 @@ def cost(self, method):
self._cost = check_callable(method)


def grad2prox(grad_op, step):
    r"""Generate a proximity operator from a gradient operator.

    This uses a first order approximation (see Notes).

    Parameters
    ----------
    grad_op: GradBase
        Gradient operator.
    step: float
        Gradient descent step.

    Returns
    -------
    ProximityParent
        Proximal operator performing a gradient descent step.

    Notes
    -----
    Let :math:`f` be a differentiable function. Its proximal operator
    is the resolvent of its gradient:

    .. math:: prox_{\lambda f}(x) = (I + \lambda\nabla f)^{-1}(x)

    and a first order approximation yields

    .. math:: prox_{\lambda f}(x) \approx (I - \lambda\nabla f)(x)
    """

    def _op(input_data, extra_factor=step):
        # One explicit gradient descent step; ``extra_factor`` plays
        # the role of the step size and defaults to the captured
        # ``step`` argument.
        grad_op.get_grad(input_data)
        return input_data - extra_factor * grad_op.grad

    return ProximityParent(_op, grad_op.cost)


class IdentityProx(ProximityParent):
"""Identity Proxmity Operator.

Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ per-file-ignores =
# - Rethink subscript slice assignment
# - Reduce complexity of KSupportNorm
# - Check bitwise operations
modopt/opt/proximity.py: WPS220,WPS231,WPS352,WPS362,WPS465,WPS506,WPS508
# - Nested function declaration
modopt/opt/proximity.py: WPS220,WPS231,WPS352,WPS362,WPS465,WPS506,WPS508,WPS430
#Todo: Consider changing cwbReweight name
modopt/opt/reweight.py: N801
#Justification: Needed to import matplotlib.pyplot
Expand Down