Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 116 additions & 69 deletions pycs/sparsity/mrs/mrs_starlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,61 @@
Related Geometric Multiscale Analysis,
Cambridge University Press, Cambridge (GB), 2016.

Example how to use the Class:
Example how to use the Class with 5 scales i.e. 4 wavelet scales + coarse resolution):
CW = MRS_starlet() # Create the class
CW.init_starlet(Nside, nscale=5)
CW.transform(Image) # Starlet transform of a 2D np array
CW.stat() # print statistics of all scales
r = CW.recons() # reconstruct an image from its coefficients
CW.plot_filter() # plot the filters in harmonic space which are used in the wavelet decomposition
more examples are given at the end of this file.

Class variables are:
nx = 0 # number of pixel of the healpix map
ns = 0 # number of scales
coef = 0.0 # Starlet coefficients
TabNorm = 0.0 # Coefficient normalixation table
SigmaNoise = 1.0 # noise standard deviation
TabNsigma = 0 # detection level per scale
verb = False
nside = 0 # nside of the input map
lmax = 0 # lmax used in spherical harmonic decomposition
ALM_iter = 0 # numnber of iteration for the inverse spherical harmonic decomposition
TabNameCode = ["Full python", "c++ Binding", "c++ binary"]
TypeCode = 0 # 0 for 'Full python', '1' for 'c++ Binding' and 2 for'c++ binary'
Tablmax =0 # lmax for each scale
TabSigma =0 # Standard deviation of the Gaussian which fits the scaling function at every scale
TabPhi = 0 # Scaling function for each scale
TabPsi = 0 # Wavelet function for each scale
Tabh = 0 # h filter for each scale
Tabg = 0 # g filter for each scale
TabResol = 0 # Resolution of each wavelet scale in arc minute
PixelResol = 0 # pixel sizr in arc minute
l2norm = False # if True, normlaize the coefficients (l2 normalization) such that the noise standart deviation remains constant through the scales.
NeedletFilter = False # If True, use needlet filters instead of spline filters

Class functions are:
def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None, Needlet=None):
def info(self): # Print information relative to the intialisation.
def stat(self): # Print Min, Max, Mean and standard deviation of all scales.
def plot_filter(self, wavelet=True, scaling=False, hfilter=False, gfilter=False): # plot the filters which are used.
def transform(self, data, WTname=None, opt=None): # Compute the wavelet transform
WaveletScale = def get_scale(self, j): # returns the wavelet coefficients at a given scale : return self.coef[j,:]
def put_scale(self, ScaleCoef, j): # insert a scale in the wavelet transform: self.coef[j,:] = ScaleCoef
Rec = recons(self): reconstruct an image from its wavelet coefficients
DenoiseMap = denoising(self, data, SigmaNoise=0, Nsigma=3, ThresCoarse=False, hard=True): # perform a denoising in wavelet space

def tvs(self,j,min=None,max=None,title=None,sigma=None,lut=None,filename=None, dpi=100) : plot the scale j
def tv(self, log=False, unit="", title="", minimum=None, maximum=None, cbar=True): plot all wavelet scales
def dump(self) : print all variable values of the class
SigmaNoise = get_noise(self) : estimate the noise standard deviation from the first wavelet scale
TabNsigma = get_tabsigma(self, Nsigma=3) # Cretate the table for the detection level. Nsigma can be either a number of an array.
def threshold(self,SigmaNoise=0,Nsigma=3,ThresCoarse=False,hard=True,FirstDetectScale=0,KillCoarse=False,Verbose=False): # Threshold the wavelet coefficient
CopiedClass = copy(self, name="wt"): # Return a copy of the class
def eval_computation_time(self): # Compare the computation time of different implementations of the wavelet transform
"""

import importlib.util

spec = importlib.util.find_spec("pymrs")
if spec is None:
# print("pymrs is available at:", spec.origin)
MRS_CXX = False
else:
pymrs = importlib.util.module_from_spec(spec)
spec.loader.exec_module(pymrs)
MRS_CXX = True


import numpy as np
import random
import os, sys
from scipy import ndimage
import healpy as hp
from astropy.io import fits
import matplotlib.pyplot as plt
from astropy.io import fits
from importlib import reload
from pycs.misc.cosmostat_init import *
from pycs.misc.stats import *
from pycs.misc.mr_prog import *
from pycs.sparsity.mrs.mrs_tools import *
import getpass
import time
from scipy.optimize import curve_fit




Expand All @@ -67,13 +86,14 @@ def test_mrs_class(Init=None):
else:
Nside = 1024
d = np.random.normal(size=(Nside**2 * 12))
Ns = 5
Ns = 7
ALM_iter = 0
Needlet=False
lmax=2048
C = CMRStarlet()
C.init_starlet(Nside, nscale=Ns, ALM_iter=ALM_iter)
C.init_starlet(Nside, nscale=Ns, ALM_iter=ALM_iter, lmax=lmax, Needlet=Needlet)
print("PYTHON Code computation time:")
start = time.time()
C.TypeCode=1
C.transform(d)
end = time.time()
print(f"Execution time python: {end - start:.4f} seconds")
Expand Down Expand Up @@ -197,7 +217,7 @@ class CMRStarlet:
SigmaNoise = 1.0 # noise standard deviation
TabNsigma = 0 # detection level per scale
verb = False
nside = 0
nside = 0
lmax = 0
ALM_iter = 0
TabNameCode = ["Full python", "c++ Binding", "c++ binary"]
Expand All @@ -211,6 +231,7 @@ class CMRStarlet:
TabResol = 0 # Resolution of each wavelet scale in arc minute
PixelResol = 0 # pixel sizr in arc minute
l2norm = False
NeedletFilter = False # If True, use need filters instead of spline filters

# __init__ is the constructor
def __init__(self, name="wt", verb=False):
Expand All @@ -230,7 +251,7 @@ def __init__(self, name="wt", verb=False):
self.name = name # self.name is an object variable
self.verb = verb

def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None):
def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None, Needlet=None):
"""
Initialize the scale for a given image size and a number of scales.
Parameters
Expand All @@ -245,6 +266,10 @@ def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None):
-------
None.
"""
if Needlet:
self.NeedletFilter=True
self.TypeCode = 0

self.nside = np.int64(nside)
self.nx = 12 * self.nside * self.nside
if lmax != 0:
Expand All @@ -265,6 +290,8 @@ def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None):
else:
if nscale == 0:
nscale = np.int64(np.log(self.nside) // 1) + 1
self.TabSigma = np.zeros(nscale)

self.ns = np.int64(nscale)

if ALM_iter != 0:
Expand All @@ -274,12 +301,27 @@ def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None):
CMRS = pymrs.MRS()
CMRS.alloc(nside, self.ns, self.lmax, self.ALM_iter, self.verb)

if TabResolSigma is None:
self.Tablmax, self.TabSigma, self.TabPhi, self.TabPsi, self.Tabh, self.Tabg = get_default_filters(self.nside, self.ns)
# print("Default TabSigma = ", self.TabSigma)
if Needlet is None:
if TabResolSigma is None:
self.Tablmax, self.TabSigma, self.TabPhi, self.TabPsi, self.Tabh, self.Tabg = get_default_filters(self.nside, self.ns)
# print("Default TabSigma = ", self.TabSigma)
else:
self.TabPhi, self.TabPsi, self.Tabh, self.Tabg = get_sigmafilters(self.TabSigma, self.lmax, Phi0Spline=False)
else:
self.TabPhi, self.TabPsi, self.Tabh, self.Tabg = get_sigmafilters(self.TabSigma, self.lmax, Phi0Spline=False)

filters = mrs_needlet_filters(self.lmax, NbrScale=self.ns)
self.Tabh = filters["TabFilterH"]
self.Tabg = filters["TabFilterG"]
self.TabPhi = filters["TabPhi"]
self.TabPsi = filters["TabPsi"]
self.Tablmax = np.zeros((self.ns))
lm = self.lmax
for j in range(0,self.ns):
self.Tablmax[j] = lm
lm = lm / 2
self.TabSigma = splinelmax2sigma(self.Tablmax)

# print("ns = ", self.ns, self.TabSigma.shape )

self.TabResol = np.zeros(self.ns)
self.TabResol[0] = self.TabSigma[0]
for j in range(1,self.ns-1):
Expand All @@ -295,11 +337,12 @@ def init_starlet(self, nside, nscale=0, lmax=0, ALM_iter=0, TabResolSigma=None):
DeltaPhi0 = np.sqrt(1. - Phi0Lmax**2)
# the alm m around lmax are often not accurate, and 1 if a conservative value, the
# the correct value is most likely between 0.9 and 1.
self.TabNorm[0] = np.sqrt( DeltaPhi0**2 + (sigma_filter(self.TabPsi[:,0], self.nside, lmax=self.lmax, PixelWindow=PixelWindow))**2)
for j in range(1,self.ns):
for j in range(0,self.ns):
self.TabNorm[j] = sigma_filter(self.TabPsi[:,j], self.nside, lmax=self.lmax, PixelWindow=PixelWindow)

def info(self): # sound is a method (a method is a function of an object)
if Needlet is None:
self.TabNorm[0] = np.sqrt( DeltaPhi0**2 + (sigma_filter(self.TabPsi[:,0], self.nside, lmax=self.lmax, PixelWindow=PixelWindow))**2)

def info(self):
"""
Print information relative to the intialisation.
"""
Expand Down Expand Up @@ -377,7 +420,10 @@ def transform(self, data, WTname=None, opt=None):
self.coef = mrs_uwttrans(im, self.ns, self.lmax, opt=opt, verbose=self.verb, path="./", cxx=True)
else:
# print(self.ns, self.TabPhi.shape)
self.coef = wt_phi_filter_trans(im, self.TabPhi)
if self.NeedletFilter is False:
self.coef = wt_phi_filter_trans(im, self.TabPhi)
else:
self.coef = mrs_needlet_transform(im, self.TabPsi)
# self.coef = mrs_uwttrans(im,self.ns,self.lmax,opt=None,verbose=self.verb,path="./",cxx=False)

if self.l2norm is True:
Expand All @@ -400,7 +446,12 @@ def recons(self):
if self.l2norm is True:
for j in range(self.ns):
self.coef[j, :] = self.coef[j, :] * self.TabNorm[j]
return np.sum(self.coef, axis=0)

if self.NeedletFilter is False:
rec = np.sum(self.coef, axis=0)
else:
rec = mrs_needlet_recons(self.coef, self.TabPsi)
return rec

def denoising(self, data, SigmaNoise=0, Nsigma=3, ThresCoarse=False, hard=True):
"""
Expand Down Expand Up @@ -435,6 +486,21 @@ def denoising(self, data, SigmaNoise=0, Nsigma=3, ThresCoarse=False, hard=True):
)
return self.recons()

def get_scale(self, j):
"""
Return the scale j in self.coef
Parameters
----------
j : int
Scale number. It must be in [0:self.ns].
Returns
-------
None.

"""
return self.coef[j, :]


def put_scale(self, ScaleCoef, j):
"""
Replace the scale j in self.coef by the 2D array ScaleCoef.
Expand All @@ -451,17 +517,8 @@ def put_scale(self, ScaleCoef, j):
"""
self.coef[j, :] = ScaleCoef

def tvs(
self,
j,
min=None,
max=None,
title=None,
sigma=None,
lut=None,
filename=None,
dpi=100,
):

def tvs(self,j,min=None,max=None,title=None,sigma=None,lut=None,filename=None, dpi=100):
"""
Display the scale j
Parameters
Expand Down Expand Up @@ -561,16 +618,7 @@ def get_tabsigma(self, Nsigma=3):
TabNsigma = Nsigma[:nscale]
return TabNsigma

def threshold(
self,
SigmaNoise=0,
Nsigma=3,
ThresCoarse=False,
hard=True,
FirstDetectScale=0,
KillCoarse=False,
Verbose=False,
):
def threshold(self,SigmaNoise=0,Nsigma=3,ThresCoarse=False,hard=True,FirstDetectScale=0,KillCoarse=False,Verbose=False):
"""
Apply a hard or a soft thresholding on the coefficients self.coef
Parameters
Expand Down Expand Up @@ -658,6 +706,7 @@ def copy(self, name="wt"):
x = self
x.name = name
x.coef = np.zeros((x.ns, x.nx))
x.coef[:,:] = self.coef[:,:]
x.TabNorm = np.copy(self.TabNorm)
return x

Expand Down Expand Up @@ -928,7 +977,7 @@ def plotsig(T, x=None, title="Spherical wavelet Filters", xlabel="X", ylabel="T[

plt.figure(figsize=(10, 6))

plt.plot(x, T / T.max(), label=f"{legend_prefix}")
plt.plot(x, T , label=f"{legend_prefix}")

plt.title(title)
plt.xlabel(xlabel)
Expand Down Expand Up @@ -1120,8 +1169,6 @@ def test_wt_hfilter_trans():
w1 = wt_trans(d, nscales=4)
return wts



def get_sigma_from_spline(lmax, hfilter=False):
# Range of l values (spherical harmonics degrees)
l_vals = np.arange(0, lmax+1)
Expand Down