Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 261 additions & 0 deletions modopt/opt/linear/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
except ImportError:
pywt_available = False

ptwt_available = True
try:
import ptwt
import torch
import cupy as cp
except ImportError:
ptwt_available = False


class WaveletConvolve(LinearParent):
"""Wavelet Convolution Class.
Expand Down Expand Up @@ -54,10 +62,56 @@ def __init__(self, filters, method='scipy'):




class WaveletTransform(LinearParent):
"""
2D and 3D wavelet transform class.

This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU using Pytorch).

Parameters
----------
wavelet_name: str
the wavelet name to be used during the decomposition.
shape: tuple[int,...]
Shape of the input data. The shape should be a tuple of length 2 or 3.
It should not contains coils or batch dimension.
nb_scales: int, default 4
the number of scales in the decomposition.
mode: str, default "zero"
Boundary Condition mode
compute_backend: str, "numpy" or "cupy", default "numpy"
Backend library to use. "cupy" also requires a working installation of PyTorch and pytorch wavelets.

**kwargs: extra kwargs for Pywavelet or Pytorch Wavelet
"""
def __init__(self,
wavelet_name,
shape,
level=4,
mode="symmetric",
compute_backend="numpy",
**kwargs):

if compute_backend == "cupy" and ptwt_available:
self.operator = CupyWaveletTransform(wavelet=wavelet_name, shape=shape, level=level, mode=mode)
elif compute_backend == "numpy" and pywt_available:
self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, level=level, **kwargs)
else:
raise ValueError(f"Compute Backend {compute_backend} not available")


self.op = self.operator.op
self.adj_op = self.operator.adj_op

@property
def coeffs_shape(self):
return self.operator.coeffs_shape

class CPUWaveletTransform(LinearParent):
"""
2D and 3D wavelet transform class.

This is a light wrapper around PyWavelet, with multicoil support.

Parameters
Expand Down Expand Up @@ -214,3 +268,210 @@ def _adj_op(self, coeffs):
wavelet=self.wavelet,
mode=self.mode,
)


class TorchWaveletTransform:
"""Wavelet transform using pytorch."""

wavedec3_keys = ["aad", "ada", "add", "daa", "dad", "dda", "ddd"]

def __init__(
self,
shape: tuple[int, ...],
wavelet: str,
level: int,
mode: str,
):
self.wavelet = wavelet
self.level = level
self.shape = shape
self.mode = mode
self.coeffs_shape = None # will be set after op.

def op(self, data: torch.Tensor) -> list[torch.Tensor]:
"""Apply the wavelet decomposition on.

Parameters
----------
data: torch.Tensor
2D or 3D, real or complex data with last axes matching shape of
the operator.

Returns
-------
list[torch.Tensor]
list of tensor each containing the data of a subband.
"""
if data.shape == self.shape:
data = data[None, ...] # add a batch dimension

if len(self.shape) == 2:
if torch.is_complex(data):
# 2D Complex
data_ = torch.view_as_real(data)
coeffs_ = ptwt.wavedec2(
data_, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2)
)
# flatten list of tuple of tensors to a list of tensors
coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [
torch.view_as_complex(cc.contiguous())
for c in coeffs_[1:]
for cc in c
]

return coeffs
# 2D Real
coeffs_ = ptwt.wavedec2(
data, self.wavelet, level=self.level, mode=self.mode, axes=(-2, -1)
)
return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c]

if torch.is_complex(data):
# 3D Complex
data_ = torch.view_as_real(data)
coeffs_ = ptwt.wavedec3(
data_,
self.wavelet,
level=self.level,
mode=self.mode,
axes=(-4, -3, -2),
)
# flatten list of tuple of tensors to a list of tensors
coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [
torch.view_as_complex(cc.contiguous())
for c in coeffs_[1:]
for cc in c.values()
]

return coeffs
# 3D Real
coeffs_ = ptwt.wavedec3(
data, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2, -1)
)
return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c.values()]

def adj_op(self, coeffs: list[torch.Tensor]) -> torch.Tensor:
"""Apply the wavelet recomposition.

Parameters
----------
list[torch.Tensor]
list of tensor each containing the data of a subband.

Returns
-------
data: torch.Tensor
2D or 3D, real or complex data with last axes matching shape of the
operator.

"""
if len(self.shape) == 2:
if torch.is_complex(coeffs[0]):
## 2D Complex ##
# list of tensor to list of tuple of tensor
coeffs = [torch.view_as_real(coeffs[0])] + [
tuple(torch.view_as_real(coeffs[i + k]) for k in range(3))
for i in range(1, len(coeffs) - 2, 3)
]
data = ptwt.waverec2(coeffs, wavelet=self.wavelet, axes=(-3, -2))
return torch.view_as_complex(data.contiguous())
## 2D Real ##
coeffs_ = [coeffs[0]] + [
tuple(coeffs[i + k] for k in range(3))
for i in range(1, len(coeffs) - 2, 3)
]
data = ptwt.waverec2(coeffs_, wavelet=self.wavelet, axes=(-2, -1))
return data

if torch.is_complex(coeffs[0]):
## 3D Complex ##
# list of tensor to list of tuple of tensor
coeffs = [torch.view_as_real(coeffs[0])] + [
{
v: torch.view_as_real(coeffs[i + k])
for k, v in enumerate(self.wavedec3_keys)
}
for i in range(1, len(coeffs) - 6, 7)
]
data = ptwt.waverec3(coeffs, wavelet=self.wavelet, axes=(-4, -3, -2))
return torch.view_as_complex(data.contiguous())
## 3D Real ##
coeffs_ = [coeffs[0]] + [
{v: coeffs[i + k] for k, v in enumerate(self.wavedec3_keys)}
for i in range(1, len(coeffs) - 6, 7)
]
data = ptwt.waverec3(coeffs_, wavelet=self.wavelet, axes=(-3, -2, -1))
return data


class CupyWaveletTransform(LinearParent):
"""Wrapper around torch wavelet transform to be compatible with the Modopt API."""

def __init__(
self,
shape: tuple[int, ...],
wavelet: str,
level: int,
mode: str,
):
self.wavelet = wavelet
self.level = level
self.shape = shape
self.mode = mode

self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode)
self.coeffs_shape = None # will be set after op

def op(self, data: cp.array) -> cp.ndarray:
"""Define the wavelet operator.

This method returns the input data convolved with the wavelet filter.

Parameters
----------
data: cp.ndarray
input 2D data array.

Returns
-------
coeffs: ndarray
the wavelet coefficients.
"""
data_ = torch.as_tensor(data)
tensor_list = self.operator.op(data_)
# flatten the list of tensor to a cupy array
# this requires an on device copy...
self.coeffs_shape = [c.shape for c in tensor_list]
n_tot_coeffs = np.sum([np.prod(s) for s in self.coeffs_shape])
ret = cp.zeros(n_tot_coeffs, dtype=np.complex64) # FIXME get dtype from torch
start = 0
for t in tensor_list:
stop = start + np.prod(t.shape)
ret[start:stop] = cp.asarray(t.flatten())
start = stop

return ret

def adj_op(self, data: cp.ndarray) -> cp.ndarray:
"""Define the wavelet adjoint operator.

This method returns the reconstructed image.

Parameters
----------
coeffs: cp.ndarray
the wavelet coefficients.

Returns
-------
data: ndarray
the reconstructed data.
"""
start = 0
tensor_list = [None] * len(self.coeffs_shape)
for i, s in enumerate(self.coeffs_shape):
stop = start + np.prod(s)
tensor_list[i] = torch.as_tensor(data[start:stop].reshape(s), device="cuda")
start = stop
ret_tensor = self.operator.adj_op(tensor_list)
return cp.from_dlpack(ret_tensor)
6 changes: 5 additions & 1 deletion modopt/opt/proximity.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
else:
import_sklearn = True

from modopt.base.backend import get_array_module
from modopt.base.transform import cube2matrix, matrix2cube
from modopt.base.types import check_callable
from modopt.interface.errors import warn
Expand Down Expand Up @@ -215,7 +216,10 @@ def _cost_method(self, *args, **kwargs):
Sparsity cost component

"""
cost_val = np.sum(np.abs(self.weights * self._linear.op(args[0])))
xp = get_array_module(args[0])
cost_val = xp.sum(xp.abs(self.weights * self._linear.op(args[0])))
if isinstance(cost_val, xp.ndarray):
cost_val = cost_val.item()

if 'verbose' in kwargs and kwargs['verbose']:
print(' - L1 NORM (X):', cost_val)
Expand Down
16 changes: 13 additions & 3 deletions modopt/tests/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
except ImportError:
SKLEARN_AVAILABLE = False

PTWT_AVAILABLE = True
try:
import ptwt
import cupy
except ImportError:
PTWT_AVAILABLE = False

PYWT_AVAILABLE = True
try:
import pywt
Expand Down Expand Up @@ -174,8 +181,12 @@ def case_linear_wavelet_convolve(self):

return linop, data_op, data_adj_op, res_op, res_adj_op

@pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")
def case_linear_wavelet_transform(self):
@parametrize(
compute_backend=[
pytest.param("numpy", marks=pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")),
pytest.param("cupy", marks=pytest.mark.skipif(not PTWT_AVAILABLE, reason="Pytorch Wavelet not available."))
])
def case_linear_wavelet_transform(self, compute_backend="numpy"):
linop = linear.WaveletTransform(
wavelet_name="haar",
shape=(8, 8),
Expand Down Expand Up @@ -298,7 +309,6 @@ class ProxCases:
[11.67394789, 12.87497954, 14.07601119],
[15.27704284, 16.47807449, 17.67910614],
],
]
)
array233_3 = np.array(
[
Expand Down