From f4d378d70d1633fa0044020c26c32752d96670a0 Mon Sep 17 00:00:00 2001 From: jstarck Date: Fri, 2 May 2025 13:36:02 +0300 Subject: [PATCH 1/3] starlet on the sphere --- pycs/sparsity/mrs/mrs_starlet.py | 343 +++++++++++++++++++++++++++++++ 1 file changed, 343 insertions(+) create mode 100644 pycs/sparsity/mrs/mrs_starlet.py diff --git a/pycs/sparsity/mrs/mrs_starlet.py b/pycs/sparsity/mrs/mrs_starlet.py new file mode 100644 index 0000000..8f5ad5d --- /dev/null +++ b/pycs/sparsity/mrs/mrs_starlet.py @@ -0,0 +1,343 @@ +import numpy as np +import random + +import os, sys +from scipy import ndimage +import healpy as hp +from astropy.io import fits +import matplotlib.pyplot as plt +from astropy.io import fits +from importlib import reload +from pycs.misc.cosmostat_init import * +from pycs.misc.mr_prog import * +from pycs.sparsity.mrs.mrs_tools import * + +def mrs_starlet(map, nscale=None, lmax=None): + nside = gnside(map) + if nscale is None: + Ns = np.int64(np.log2(nside) - 2) + else: + Ns = nscale + + npix = map.shape[0] + w = wt_trans(map, lmax=lmax,nscales=Ns-1) + trans = w.T + return trans + +def mrs_istarlet(trans): + r = np.sum(trans, axis=0) + return r + + +def mrs_uwttrans(map, nscale=None, lmax=None, opt=None, verbose=False, path="./", progpath=None, cxx=False): + nside = gnside(map) + if nscale is None: + Ns = np.log2(nside) - 2 + else: + Ns = nscale + + if cxx: + optParam = " " + if opt is not None: + optParam = " " + opt + if lmax is not None: + optParam = " -l " + str(lmax) + optParam + if nscale is not None: + optParam = " -n " + str(nscale) + optParam + if progpath is None: + prog = "mrs_uwttrans" + else: + prog = progpath + "mrs_uwttrans" + p = mrs_prog( + map, + prog=prog, + verbose=verbose, + opt=optParam, + OutputFormatisHealpix=False, + path=path, + ) + else: + npix = map.shape[0] + w = wt_trans(map, lmax=lmax,nscales=Ns-1) + p = np.zeros(Ns, npix) + for j in range(Ns): + print(j+1) + p[j,:] = w[:,j] + + return p + + +def mrs_uwtrecons(Tmap, lmax=None, opt=None, verbose=False, path="./", progpath=None): + optParam = " " + if opt is not None: + optParam = " " + opt + if lmax is not None: + optParam = " -l " + str(lmax) + optParam + if progpath is None: + prog = "mrs_uwttrans" + else: + prog = progpath + "mrs_uwttrans -r " + p = mrs_prog( + Tmap, + prog=prog, + verbose=verbose, + opt=optParam, + InputFormatisHealpix=False, + OutputFormatisHealpix=True, + path=path, + ) + return p + + + +# Wavelet filtering + +def spline2(size, l, lc): + """ + Compute a non-negative decreasing spline, with value 1 at index 0. + + Parameters + ---------- + size: int + size of the spline + l: float + spline parameter + lc: float + spline parameter + + Returns + ------- + np.ndarray + (size,) float array, spline + """ + + res = np.arange(0, size+1) + res = 2*l*res/(lc*size) + res = (3/2) * 1/12 * (abs(res-2)**3 - 4*abs(res-1)**3 + 6*abs(res)**3 - 4*abs(res+1)**3 + abs(res+2)**3) + return res + + +def compute_h(size, lc): + """ + Compute a low-pass filter. + + Parameters + ---------- + size: int + size of the filter + lc: float + cutoff parameter + + Returns + ------- + np.ndarray + (size,) float array, filter + """ + + tab1 = spline2(size, 2*lc, 1) + tab2 = spline2(size, lc, 1) + h = tab1/(tab2+1e-6) + h[np.int64(size/(2*lc)):size] = 0 + return h + + +def compute_g(size, lc): + """ + Compute a high-pass filter. + + Parameters + ---------- + size: int + size of the filter + lc: float + cutoff parameter + + Returns + ------- + np.ndarray + (size,) float array, filter + """ + + tab1 = spline2(size, 2*lc, 1) + tab2 = spline2(size, lc, 1) + g = (tab2-tab1)/(tab2+1e-6) + g[np.int64(size/(2*lc)):size] = 1 + return g + + +def get_wt_filters(lmax, nscales): + """Compute wavelet filters. + + Parameters + ---------- + lmax: int + maximum l + nscales: int + number of wavelet detail scales + + Returns + ------- + np.ndarray + (lmax+1,nscales+1) float array, filters + """ + + wt_filters = np.ones((lmax+1, nscales+1)) + wt_filters[:, 1:] = np.array([compute_h(lmax, 2**scale) for scale in range(nscales)]).T + wt_filters[:, :nscales] -= wt_filters[:, 1:(nscales+1)] + return wt_filters + + +def wt_trans(inputs, nscales=3, lmax=None, alm_in=False, nside=None, alm_out=False): + """Wavelet transform an array. + + Parameters + ---------- + inputs: np.ndarray + (p,) or (n,p) float array, map or stack of n maps / if alm_in, (t,) or (n,t) complex array, alm or stack + of n alms + nscales: int + number of wavelet detail scales + lmax: int + maximum l (default: 3*nside / if alm_in, deduced from inputs) + alm_in: bool + inputs is alm + nside: int + nside of the output Healpix maps (default: deduced from maps) + alm_out: bool + output is alm + + Returns + ------- + np.ndarray + (p,nscales+1) or (n,p,scales+1) float array, wavelet transform of the input array or stack of the wavelet + transforms of the n input arrays / if alm_out, (t,nscales+1) or (n,t,scales+1) complex array, alm of the + wavelet transform of the input array or stack of the alms of the wavelet transforms of the n input arrays + """ + dim_inputs = len(np.shape(inputs)) + maps = None # to remove warnings + + if alm_in: + alms = inputs + if nside is None and not alm_out: + raise ValueError("nside is missing") + if not alm_out: + maps = alm2map(alms, nside) + if lmax is None: + lmax = hp.Alm.getlmax(np.shape(alms)[-1]) + + else: + maps = inputs + if dim_inputs == 1: + nside = hp.get_nside(maps) + else: + nside = hp.get_nside(maps[0, :]) + if lmax is None: + lmax = 3 * nside + alms = map2alm(maps, lmax=lmax) + + if not alm_out: + l_scale = maps.copy() + if dim_inputs == 1: + npix = len(maps) + wts = np.zeros((npix, nscales + 1)) + else: + npix = np.shape(maps)[1] + wts = np.zeros((np.shape(maps)[0], npix, nscales + 1)) + else: + l_scale = alms.copy() + if dim_inputs == 1: + npix = np.size(alms) + wts = np.zeros((npix, nscales + 1), dtype='complex') + else: + npix = np.shape(alms)[1] + wts = np.zeros((np.shape(maps)[0], npix, nscales + 1), dtype='complex') + + scale = 1 + for j in range(nscales): + h = compute_h(lmax, scale) + if not alm_out: + m = alm2map(alm_product(alms, h), nside) + else: + m = alm_product(alms, h) + h_scale = l_scale - m + l_scale = m + if dim_inputs == 1: + wts[:, j] = h_scale + else: + wts[:, :, j] = h_scale + scale *= 2 + + if dim_inputs == 1: + wts[:, nscales] = l_scale + else: + wts[:, :, nscales] = l_scale + + return wts + + +def wt_rec(wts): + """Reconstruct a wavelet decomposition. + + Parameters + ---------- + wts: np.ndarray + (p,nscales+1) or (n,p,scales+1) float array, wavelet transform of a map or stack of the wavelet transforms of n + maps + + Returns + ------- + np.ndarray + (p,) or (n,p,) float array, reconstructed map or stack of n reconstructed maps + """ + + return np.sum(wts, axis=-1) + + +# Plots + +def mrs_tv(maps, log=False, unit='', title='', minimum=None, maximum=None, cbar=True): + """Plot one or more Healpix maps in Mollweide projection. + + Parameters + ---------- + maps: np.ndarray + (p,) or (n,p) float array, map or stack of n maps + log: bool + logarithmic scale + unit: str + unit of the data + title: str + title of the plots + minimum: float + minimum range value (default: min(maps, maps2)) + maximum: float + maximum range value (default: max(maps, maps2)) + cbar: bool + show color bar + + Returns + ------- + None + """ + + if len(np.shape(maps)) == 1: + maps = np.expand_dims(maps, axis=0) + + if minimum is None: + minimum = np.min(maps) + + if maximum is None: + maximum = np.max(maps) + + if not log: + def f(x): return x + else: + def f(x): return np.log10(x - minimum + 1) + for i in range(np.shape(maps)[0]): + if title: + tit = title + ": Scale " + str(i+1) + else: + tit = "Scale " + str(i+1) + hp.mollview(f(maps[i, :]), fig=None, unit=unit, title=tit, min=f(minimum), max=f(maximum), cbar=cbar) + + From 6e9c6732267f706f76bb1b09e9cac1ba407571e3 Mon Sep 17 00:00:00 2001 From: jstarck Date: Fri, 2 May 2025 13:36:27 +0300 Subject: [PATCH 2/3] spherical healpix tools --- pycs/sparsity/mrs/mrs_tools.py | 402 +++++++++++++++++++++++++++++---- 1 file changed, 364 insertions(+), 38 deletions(-) diff --git a/pycs/sparsity/mrs/mrs_tools.py b/pycs/sparsity/mrs/mrs_tools.py index 62a7e65..5f0b33d 100644 --- a/pycs/sparsity/mrs/mrs_tools.py +++ b/pycs/sparsity/mrs/mrs_tools.py @@ -107,7 +107,6 @@ def g2k(g1, g2): ke = hp.alm2map(ae, nside, pol=False) return ke - def k2g(ke): nside = gnside(ke) ae = hp.map2alm(ke, 1, pol=False) @@ -243,44 +242,371 @@ def tol(map, lmax_amin, amin=False): return b -def mrs_uwttrans(map, lmax=None, opt=None, verbose=False, path="./", progpath=None): - optParam = " " - if opt is not None: - optParam = " " + opt - if lmax is not None: - optParam = " -l " + str(lmax) + optParam - if progpath is None: - prog = "mrs_uwttrans" - else: - prog = progpath + "mrs_uwttrans" - p = mrs_prog( - map, - prog=prog, - verbose=verbose, - opt=optParam, - OutputFormatisHealpix=False, - path=path, - ) - return p +# Spherical harmonic transform : Code from Remi Carloni Gertosio +def map2alm(maps, lmax=None, iter=3): + """Computes the alm of a Healpix map. + + Parameters + ---------- + maps: np.ndarray + (p,) or (n,p) float array, map or stack of n maps in Healpix representation + lmax: int + maximum l of the alm (default: 3*nside) + iter: 3 + number of iterations + + Returns + ------- + np.ndarray + (t,) or (n,t) complex array, alm or stack of n alms + """ + + if len(np.shape(maps)) == 1: + if lmax is None: + lmax = 3*hp.get_nside(maps) + return hp.sphtfunc.map2alm(maps, lmax=lmax, iter=iter) + + n = np.shape(maps)[0] + if lmax is None: + lmax = 3*hp.get_nside(maps[0, :]) + return np.array([hp.sphtfunc.map2alm(maps[i, :], lmax=lmax, iter=iter) for i in range(n)]) + + +def alm2map(alms, nside): + """Computes a Healpix map given the alm. + + Parameters + ---------- + alms: np.ndarray + (t,) or (n,t) complex array, alm or stack of n alms + nside: int + nside of the output Healpix maps + + Returns + ------- + np.ndarray + (p,) or (n,p) float array, map or stack of n maps in Healpix representation + """ + + if len(np.shape(alms)) == 1: + return hp.alm2map(alms, nside) + + n = np.shape(alms)[0] + return np.array([hp.sphtfunc.alm2map(alms[i, :], nside) for i in range(n)]) + + +def alm_product(alms, filters): + """Apply an isotropic filter on an alm. + + Parameters + ---------- + alms: np.ndarray + (t,) or (n,t) complex array, alm or stack of n alms + filters: np.ndarray + (lmax+1,) or (n,lmax+1) float array, isotropic filter or stack of n isotropic filters (one filter per source) in + spherical harmonic domain + + Returns + ------- + np.ndarray + (t,) or (n,t) complex array, filtered alm or stack of n filtered alms + """ + + dim_filters = len(np.shape(filters)) + dim_alms = len(np.shape(alms)) + + if dim_filters == 1 and dim_alms == 1: + return hp.sphtfunc.smoothalm(alms, beam_window=filters, inplace=False) + + n = np.shape(alms)[0] + + if dim_filters == 1: + return np.array([hp.sphtfunc.smoothalm(alms[i, :], beam_window=filters, verbose=False, inplace=False) + for i in range(n)]) + + return np.array([hp.sphtfunc.smoothalm(alms[i, :], beam_window=filters[i, :], verbose=False, inplace=False) + for i in range(n)]) + + +def convolve(maps, filters, lmax=None, nside=None): + """Convolve maps with filters. + + Parameters + ---------- + maps: np.ndarray + (p,) or (n,p) float array, map or stack of n maps in Healpix representation + filters: np.ndarray + (lmax+1,) or (n,lmax+1) float array, isotropic filter or stack of n isotropic filters (one filter per source) + lmax: int + maximum l of the filtering (default: deduced from filters) + nside: int + nside of the output Healpix maps (default: deduced from maps) + + Returns + ------- + maps: np.ndarray + (p,) or (n,p) float array, convolved map or stack of n convolved maps in Healpix representation + """ -def mrs_uwtrecons(Tmap, lmax=None, opt=None, verbose=False, path="./", progpath=None): - optParam = " " - if opt is not None: - optParam = " " + opt if lmax is not None: - optParam = " -l " + str(lmax) + optParam - if progpath is None: - prog = "mrs_uwttrans" + if len(np.shape(filters)) == 1: + lmax = len(filters) - 1 + else: + lmax = np.shape(filters)[1] - 1 + + alms = map2alm(maps, lmax=lmax) + + alms = alm_product(alms, filters) + + if nside is None: + nside = hp.get_nside(maps) + + return alm2map(alms, nside=nside) + + +def anafast(maps, lmax=None, iter=3): + """Computes the angular power spectrum of a Healpix map. + + Parameters + ---------- + maps: np.ndarray + (p,) or (n,p) float array, map or stack of n maps in Healpix representation + lmax: int + maximum l of the angular power spectrum (default: 3*nside of maps) + iter: 3 + number of iterations + + Returns + ------- + np.ndarray + (lmax+1,) or (n,lmax+1) float array, angular power spectrum or stack of n angular power spectra + """ + + if len(np.shape(maps)) == 1: + if lmax is None: + lmax = 3*hp.get_nside(maps) + return hp.sphtfunc.anafast(maps, lmax=lmax, iter=iter) + + n = np.shape(maps)[0] + if lmax is None: + lmax = 3 * hp.get_nside(maps[0, :]) + return np.array([hp.sphtfunc.anafast(maps[i, :], lmax=lmax) for i in range(n)]) + + +def alm2cl(alms): + """Computes the angular power spectrum from an alm. + + Parameters + ---------- + alms: np.ndarray + (t,) or (n,t) complex array, alm or stack of n alms + + Returns + ------- + np.ndarray + (lmax+1,) or (n,lmax+1) float array, angular power spectrum or stack of n angular power spectra + """ + + if len(np.shape(alms)) == 1: + return hp.sphtfunc.alm2cl(alms) + + n = np.shape(alms)[0] + return np.array([hp.sphtfunc.alm2cl(alms[i, :]) for i in range(n)]) + + +# Alm index computation + +def getsize(lmax): + """Returns the size of the array needed to store alm up to lmax. + + Parameters + ---------- + lmax: int + maximum l of the alm + + Returns + ------- + int + size of the array needed to store alm up to lmax + + """ + + return hp.Alm.getsize(lmax) + + +def getlm(lmax): + """Get the mapping of an alm. + + Parameters + ---------- + lmax: int + maximum l of the alm + + Returns + ------- + (np.ndarray,np.ndarray) + l to index map, + m to index map + """ + + return hp.Alm.getlm(lmax) + + +def npix2nside(npix): + """ + Give the nside parameter for the given number of pixels. + + Parameters + ---------- + npix: int + number of pixels + + Returns + ------- + nside: int + nside + """ + + return hp.npix2nside(npix) + + +# Plots + +def mollview(maps, maps2=None, log=False, unit='', title='', minimum=None, maximum=None, cbar=True): + """Plot one or more Healpix maps in Mollweide projection. + + Parameters + ---------- + maps: np.ndarray + (p,) or (n,p) float array, map or stack of n maps + maps2: np.ndarray + (p,) or (n,p) float array, second map or stack of n maps, optional + log: bool + logarithmic scale + unit: str + unit of the data + title: str + title of the plots + minimum: float + minimum range value (default: min(maps, maps2)) + maximum: float + maximum range value (default: max(maps, maps2)) + cbar: bool + show color bar + + Returns + ------- + None + """ + + if len(np.shape(maps)) == 1: + maps = np.expand_dims(maps, axis=0) + if maps2 is not None: + maps2 = np.expand_dims(maps2, axis=0) + if minimum is None: + minimum = np.min(maps) + if maps2 is not None: + minimum = np.min([minimum, np.min(maps2)]) + if maximum is None: + maximum = np.max(maps) + if maps2 is not None: + maximum = np.max([maximum, np.max(maps2)]) + if not log: + def f(x): return x else: - prog = progpath + "mrs_uwttrans -r " - p = mrs_prog( - Tmap, - prog=prog, - verbose=verbose, - opt=optParam, - InputFormatisHealpix=False, - OutputFormatisHealpix=True, - path=path, - ) - return p + def f(x): return np.log10(x - minimum + 1) + for i in range(np.shape(maps)[0]): + hp.mollview(f(maps[i, :]), fig=None, unit=unit, title=title, min=f(minimum), max=f(maximum), cbar=cbar) + if maps2 is not None: + hp.mollview(f(maps2[i, :]), fig=None, unit=unit, title=title, min=f(minimum), max=f(maximum), cbar=cbar) + +def view_spec(inputs, lmax=None, alm_in=False): + """Plot the angular power spectrum of one or several maps. + + Parameters + ---------- + inputs: np.ndarray + (p,) or (n,p) float array, map or stack of n maps / if alm_in, (t,) or (n,t) complex array, alm or stack + of n alms + lmax: int + maximum l (default: 3*nside / if alm_in, deduced from inputs) + alm_in: bool + inputs is alm + + Returns + ------- + None + """ + + if len(np.shape(inputs)) == 1: + inputs = np.expand_dims(inputs, axis=0) + + if not alm_in: + cls = anafast(inputs, lmax=lmax) + else: + cls = alm2cl(inputs) + + plt.figure() + for i in range(np.shape(inputs)[0]): + plt.semilogy(cls[i, :], label='Source '+str(i+1)) + plt.xlabel('$l$') + plt.ylabel('$c_l$') + if np.shape(inputs)[0] != 1: + plt.legend() + +# Miscellaneous + +def getidealbeam(lmax, cutmin=None, cutmax=None): + """Compute a beam, with value 1 until a first cutoff frequency and 0 after a second cutoff frequency. The transition + is computed with a spline. + + Parameters + ---------- + lmax: int + maximum l + cutmin: int + frequency below which filter is 1 (default: int((lmax+1)/4)) + cutmax: int + frequency above which filter is 0 (default: int((lmax+1)/2)) + + Returns + ------- + np.ndarray + (lmax+1,) float array, filter + """ + + if cutmin is None: + cutmin = np.int64((lmax+1)/4) + if cutmax is None: + cutmax = np.int64((lmax+1)/2) + bl = np.zeros(lmax+1) + bl[0:cutmin] = 1 + bl[cutmin:cutmax] = spline2(cutmax-cutmin-1, 1, 1) + return bl + + +def getbeam(fwhm=100, lmax=512): + """Get a spherical Gaussian-shaped beam. + + Parameters + ---------- + fwhm: float + full width at half maximum in the harmonic space (in terms of l) + lmax: int + maximum l + + Returns + ------- + np.ndarray + (lmax+1,) float array, Gaussian-shaped beam + """ + + tor = 0.0174533 + if len(np.shape(fwhm)) == 1: + fwhm = np.expand_dims(fwhm, axis=1) + F = fwhm / 60 * tor + l = np.arange(0, lmax+1) + ell = l * (l + 1) + bl = np.exp(-ell * F * F / 16 / np.log(2)) + return bl From 60cf2b2e684ecf364fbaf3f6b46a2d8d939de21c Mon Sep 17 00:00:00 2001 From: jstarck Date: Fri, 2 May 2025 13:37:27 +0300 Subject: [PATCH 3/3] remove deprecated import imp --- pycs/sparsity/sparse2d/starlet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pycs/sparsity/sparse2d/starlet.py b/pycs/sparsity/sparse2d/starlet.py index 38f342c..1fe37e4 100644 --- a/pycs/sparsity/sparse2d/starlet.py +++ b/pycs/sparsity/sparse2d/starlet.py @@ -40,7 +40,7 @@ import sys -import imp +# import imp PYSAP_CXX = True try: