From 28a374bc547d5b134fdf184b3072de4c3cc49ee7 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 15 Aug 2023 17:41:19 -0400 Subject: [PATCH 01/13] DEV: make sure all priors return float when needed --- bilby/core/prior/analytical.py | 95 ++++++++++++++------------------ bilby/core/prior/base.py | 7 ++- bilby/core/prior/dict.py | 39 ++++++++----- bilby/core/prior/interpolated.py | 8 +-- bilby/core/utils/calculus.py | 25 ++++++++- bilby/gw/prior.py | 27 ++------- test/core/prior/prior_test.py | 9 +++ 7 files changed, 111 insertions(+), 99 deletions(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 5e7b3099f..54148d1b1 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -158,8 +158,10 @@ def cdf(self, val): _cdf = (np.log(val / self.minimum) / np.log(self.maximum / self.minimum)) else: - _cdf = np.atleast_1d(val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) / \ - (self.maximum ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) + _cdf = ( + (val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) + / (self.maximum ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) + ) _cdf = np.minimum(_cdf, 1) _cdf = np.maximum(_cdf, 0) return _cdf @@ -356,16 +358,16 @@ def ln_prob(self, val): return np.nan_to_num(- np.log(2 * np.abs(val)) - np.log(np.log(self.maximum / self.minimum))) def cdf(self, val): - val = np.atleast_1d(val) norm = 0.5 / np.log(self.maximum / self.minimum) - cdf = np.zeros((len(val))) - lower_indices = np.where(np.logical_and(-self.maximum <= val, val <= -self.minimum))[0] - upper_indices = np.where(np.logical_and(self.minimum <= val, val <= self.maximum))[0] - cdf[lower_indices] = -norm * np.log(-val[lower_indices] / self.maximum) - cdf[np.where(np.logical_and(-self.minimum < val, val < self.minimum))] = 0.5 - cdf[upper_indices] = 0.5 + norm * np.log(val[upper_indices] / self.minimum) - cdf[np.where(self.maximum < val)] = 1 - return cdf + _cdf = ( + -norm * 
np.log(abs(val) / self.maximum) + * (val <= -self.minimum) * (val >= -self.maximum) + + (0.5 + norm * np.log(abs(val) / self.minimum)) + * (val >= self.minimum) * (val <= self.maximum) + + 0.5 * (val >= -self.minimum) * (val <= self.minimum) + + 1 * (val > self.maximum) + ) + return _cdf class Cosine(Prior): @@ -415,10 +417,12 @@ def prob(self, val): return np.cos(val) / 2 * self.is_in_prior_range(val) def cdf(self, val): - _cdf = np.atleast_1d((np.sin(val) - np.sin(self.minimum)) / - (np.sin(self.maximum) - np.sin(self.minimum))) - _cdf[val > self.maximum] = 1 - _cdf[val < self.minimum] = 0 + _cdf = ( + (np.sin(val) - np.sin(self.minimum)) + / (np.sin(self.maximum) - np.sin(self.minimum)) + * (val >= self.minimum) * (val <= self.maximum) + + 1 * (val > self.maximum) + ) return _cdf @@ -469,10 +473,12 @@ def prob(self, val): return np.sin(val) / 2 * self.is_in_prior_range(val) def cdf(self, val): - _cdf = np.atleast_1d((np.cos(val) - np.cos(self.minimum)) / - (np.cos(self.maximum) - np.cos(self.minimum))) - _cdf[val > self.maximum] = 1 - _cdf[val < self.minimum] = 0 + _cdf = ( + (np.cos(val) - np.cos(self.minimum)) + / (np.cos(self.maximum) - np.cos(self.minimum)) + * (val >= self.minimum) * (val <= self.maximum) + + 1 * (val > self.maximum) + ) return _cdf @@ -614,11 +620,13 @@ def prob(self, val): / self.sigma / self.normalisation * self.is_in_prior_range(val) def cdf(self, val): - val = np.atleast_1d(val) - _cdf = (erf((val - self.mu) / 2 ** 0.5 / self.sigma) - erf( - (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 / self.normalisation - _cdf[val > self.maximum] = 1 - _cdf[val < self.minimum] = 0 + _cdf = ( + ( + erf((val - self.mu) / 2 ** 0.5 / self.sigma) + - erf((self.minimum - self.mu) / 2 ** 0.5 / self.sigma) + ) / 2 / self.normalisation * (val >= self.minimum) * (val <= self.maximum) + + 1 * (val > self.maximum) + ) return _cdf @@ -1354,6 +1362,8 @@ def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None, raise ValueError("For the 
Fermi-Dirac prior the values of sigma and r " "must be positive.") + self.expr = np.exp(self.r) + def rescale(self, val): """ 'Rescale' a sample from the unit line element to the appropriate Fermi-Dirac prior. @@ -1371,21 +1381,8 @@ def rescale(self, val): .. [1] M. Pitkin, M. Isi, J. Veitch & G. Woan, `arXiv:1705.08978v1 `_, 2017. """ - inv = (-np.exp(-1. * self.r) + (1. + np.exp(self.r)) ** -val + - np.exp(-1. * self.r) * (1. + np.exp(self.r)) ** -val) - - # if val is 1 this will cause inv to be negative (due to numerical - # issues), so return np.inf - if isinstance(val, (float, int)): - if inv < 0: - return np.inf - else: - return -self.sigma * np.log(inv) - else: - idx = inv >= 0. - tmpinv = np.inf * np.ones(len(np.atleast_1d(val))) - tmpinv[idx] = -self.sigma * np.log(inv[idx]) - return tmpinv + inv = -1 / self.expr + (1 + self.expr)**-val + (1 + self.expr)**-val / self.expr + return -self.sigma * np.log(np.maximum(inv, 0)) def prob(self, val): """Return the prior probability of val. @@ -1398,7 +1395,11 @@ def prob(self, val): ======= float: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return ( + (np.exp((val - self.mu) / self.sigma) + 1)**-1 + / (self.sigma * np.log1p(self.expr)) + * (val >= self.minimum) + ) def ln_prob(self, val): """Return the log prior probability of val. @@ -1411,19 +1412,7 @@ def ln_prob(self, val): ======= Union[float, array_like]: Log prior probability of val """ - - norm = -np.log(self.sigma * np.log(1. + np.exp(self.r))) - if isinstance(val, (float, int)): - if val < self.minimum: - return -np.inf - else: - return norm - np.logaddexp((val / self.sigma) - self.r, 0.) - else: - val = np.atleast_1d(val) - lnp = -np.inf * np.ones(len(val)) - idx = val >= self.minimum - lnp[idx] = norm - np.logaddexp((val[idx] / self.sigma) - self.r, 0.) 
- return lnp + return np.log(self.prob(val)) class Categorical(Prior): diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index 4ac924f74..3fdad8289 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -5,7 +5,6 @@ import numpy as np import scipy.stats -from scipy.interpolate import interp1d from ..utils import ( infer_args_from_method, @@ -13,6 +12,7 @@ decode_bilby_json, logger, get_dict_with_properties, + WrappedInterp1d as interp1d, ) @@ -178,7 +178,10 @@ def cdf(self, val): cdf = cumtrapz(pdf, x, initial=0) interp = interp1d(x, cdf, assume_sorted=True, bounds_error=False, fill_value=(0, 1)) - return interp(val) + output = interp(val) + if isinstance(val, (int, float)): + output = float(output) + return output def ln_prob(self, val): """Return the prior ln probability of val, this should be overwritten diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index d888e4b8f..55ec91732 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -471,7 +471,9 @@ def sample_subset_constrained(self, keys=iter([]), size=None): def normalize_constraint_factor( self, keys, min_accept=10000, sampling_chunk=50000, nrepeats=10 ): - if keys in self._cached_normalizations.keys(): + if len(self.constraint_keys) == 0: + return 1 + elif keys in self._cached_normalizations.keys(): return self._cached_normalizations[keys] else: factor_estimates = [ @@ -533,8 +535,10 @@ def check_prob(self, sample, prob): return 0.0 else: constrained_prob = np.zeros_like(prob) - keep = np.array(self.evaluate_constraints(sample), dtype=bool) - constrained_prob[keep] = prob[keep] * ratio + in_bounds = np.isfinite(prob) + subsample = {key: sample[key][in_bounds] for key in sample} + keep = np.array(self.evaluate_constraints(subsample), dtype=bool) + constrained_prob[in_bounds] = prob[in_bounds] * keep * ratio return constrained_prob def ln_prob(self, sample, axis=None): @@ -568,8 +572,10 @@ def check_ln_prob(self, sample, ln_prob): return -np.inf 
else: constrained_ln_prob = -np.inf * np.ones_like(ln_prob) - keep = np.array(self.evaluate_constraints(sample), dtype=bool) - constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio) + in_bounds = np.isfinite(ln_prob) + subsample = {key: sample[key][in_bounds] for key in sample} + keep = np.log(np.array(self.evaluate_constraints(subsample), dtype=bool)) + constrained_ln_prob[in_bounds] = ln_prob[in_bounds] + keep + np.log(ratio) return constrained_ln_prob def cdf(self, sample): @@ -603,10 +609,8 @@ def rescale(self, keys, theta): ======= list: List of floats containing the rescaled sample """ - from matplotlib.cbook import flatten - return list( - flatten([self[key].rescale(sample) for key, sample in zip(keys, theta)]) + [self[key].rescale(sample) for key, sample in zip(keys, theta)] ) def test_redundancy(self, key, disable_logging=False): @@ -629,9 +633,7 @@ def test_has_redundant_keys(self): del temp[key] if temp.test_redundancy(key, disable_logging=True): logger.warning( - "{} is a redundant key in this {}.".format( - key, self.__class__.__name__ - ) + f"{key} is a redundant key in this {self.__class__.__name__}." 
) redundant = True return redundant @@ -830,13 +832,12 @@ def rescale(self, keys, theta): ======= list: List of floats containing the rescaled sample """ - from matplotlib.cbook import flatten - keys = list(keys) theta = list(theta) self._check_resolved() self._update_rescale_keys(keys) result = dict() + joint = dict() for key, index in zip( self.sorted_keys_without_fixed_parameters, self._rescale_indexes ): @@ -844,7 +845,17 @@ def rescale(self, keys, theta): theta[index], **self.get_required_variables(key) ) self[key].least_recently_sampled = result[key] - return list(flatten([result[key] for key in keys])) + if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint: + joint[self[key].dist.distname] = [key] + elif isinstance(self[key], JointPrior): + joint[self[key].dist.distname].append(key) + for names in joint.values(): + values = list() + for key in names: + values = np.concatenate([values, result[key]]) + for key, value in zip(names, values): + result[key] = value + return list([result[key] for key in keys]) def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: diff --git a/bilby/core/prior/interpolated.py b/bilby/core/prior/interpolated.py index 187e8a60a..0b37656cd 100644 --- a/bilby/core/prior/interpolated.py +++ b/bilby/core/prior/interpolated.py @@ -1,8 +1,7 @@ import numpy as np -from scipy.interpolate import interp1d from .base import Prior -from ..utils import logger +from ..utils import logger, WrappedInterp1d as interp1d class Interped(Prior): @@ -86,10 +85,7 @@ def rescale(self, val): This maps to the inverse CDF. This is done using interpolation. 
""" - rescaled = self.inverse_cumulative_distribution(val) - if rescaled.shape == (): - rescaled = float(rescaled) - return rescaled + return self.inverse_cumulative_distribution(val) @property def minimum(self): diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index 618061f48..9f8d2572d 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -1,7 +1,7 @@ import math from numbers import Number import numpy as np -from scipy.interpolate import interp2d +from scipy.interpolate import interp1d, interp2d from scipy.special import logsumexp from .log import logger @@ -264,6 +264,29 @@ def _sanitize_inputs(x, y): return x, y +class WrappedInterp1d(interp1d): + """ + A wrapper around scipy interp1d which sets equality-by-instantiation and + makes sure that the output is a float if the input is a float or int. + """ + def __call__(self, x): + output = super().__call__(x) + if isinstance(x, (float, int)): + output = output.item() + return output + + def __eq__(self, other): + for key in self.__dict__: + if type(self.__dict__[key]) is np.ndarray: + if not np.array_equal(self.__dict__[key], other.__dict__[key]): + return False + elif key == "_spline": + pass + elif getattr(self, key) != getattr(other, key): + return False + return True + + def round_up_to_power_of_two(x): """Round up to the next power of two diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 63b51790e..7a73456b3 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -2,7 +2,7 @@ import copy import numpy as np -from scipy.interpolate import InterpolatedUnivariateSpline, interp1d +from scipy.interpolate import InterpolatedUnivariateSpline from scipy.special import hyp2f1 from scipy.stats import norm @@ -12,7 +12,7 @@ ConditionalPriorDict, ConditionalBasePrior, BaseJointPriorDist, JointPrior, JointPriorDistError, ) -from ..core.utils import infer_args_from_method, logger, random +from ..core.utils import infer_args_from_method, logger, random, 
WrappedInterp1d as interp1d from .conversion import ( convert_to_lal_binary_black_hole_parameters, convert_to_lal_binary_neutron_star_parameters, generate_mass_parameters, @@ -377,21 +377,6 @@ def __init__(self, minimum, maximum, name='chirp_mass', name=name, latex_label=latex_label, unit=unit, boundary=boundary) -class WrappedInterp1d(interp1d): - """ A wrapper around scipy interp1d which sets equality-by-instantiation """ - def __eq__(self, other): - - for key in self.__dict__: - if type(self.__dict__[key]) is np.ndarray: - if not np.array_equal(self.__dict__[key], other.__dict__[key]): - return False - elif key == "_spline": - pass - elif getattr(self, key) != getattr(other, key): - return False - return True - - class UniformInComponentsMassRatio(Prior): r""" Prior distribution for chirp mass which is uniform in component masses. @@ -435,7 +420,7 @@ def __init__(self, minimum, maximum, name='mass_ratio', latex_label='$q$', latex_label=latex_label, unit=unit, boundary=boundary) self.norm = self._integral(maximum) - self._integral(minimum) qs = np.linspace(minimum, maximum, 1000) - self.icdf = WrappedInterp1d( + self.icdf = interp1d( self.cdf(qs), qs, kind='cubic', bounds_error=False, fill_value=(minimum, maximum)) @@ -449,11 +434,7 @@ def cdf(self, val): def rescale(self, val): if self.equal_mass: val = 2 * np.minimum(val, 1 - val) - resc = self.icdf(val) - if resc.ndim == 0: - return resc.item() - else: - return resc + return self.icdf(val) def prob(self, val): in_prior = (val >= self.minimum) & (val <= self.maximum) diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 364d48e9e..319040c55 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -369,6 +369,15 @@ def test_cdf_zero_below_domain(self): ) self.assertTrue(all(np.nan_to_num(prior.cdf(outside_domain)) == 0)) + def test_cdf_float_with_float_input(self): + for prior in self.priors: + if ( + bilby.core.prior.JointPrior in prior.__class__.__mro__ + and 
prior.maximum == np.inf + ): + continue + self.assertIsInstance(prior.cdf(prior.sample()), float) + def test_log_normal_fail(self): with self.assertRaises(ValueError): bilby.core.prior.LogNormal(name="test", unit="unit", mu=0, sigma=-1) From adebfea5d9dbd2e37f1c8f13fcf35b6f2c68624d Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 12 Jul 2024 16:22:00 -0500 Subject: [PATCH 02/13] FEAT: add fermi-dirac CDF --- bilby/core/prior/analytical.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 54148d1b1..8b03960ce 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -1414,6 +1414,33 @@ def ln_prob(self, val): """ return np.log(self.prob(val)) + def cdf(self, val): + """ + Evaluate the CDF of the Fermi-Dirac distribution using a slightly + modified form of Equation 23 of [1]_. + + Parameters + ========== + val: Union[float, int, array_like] + The value(s) to evaluate the CDF at + + Returns + ======= + Union[float, array_like]: + The CDF value(s) + + References + ========== + + .. [1] M. Pitkin, M. Isi, J. Veitch & G. Woan, `arXiv:1705.08978v1 + `_, 2017. 
+ """ + result = ( + (np.logaddexp(0, -self.r) - np.logaddexp(-val / self.sigma, -self.r)) + / np.logaddexp(0, self.r) + ) + return np.clip(result, 0, 1) + class Categorical(Prior): def __init__(self, ncategories, name=None, latex_label=None, From a58440bee6b73ae4d808004fb2e8d09847546548 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 12 Jul 2024 16:35:08 -0500 Subject: [PATCH 03/13] TST: add testing of fermi dirac and symmetric log uniform priors --- test/core/prior/prior_test.py | 39 +++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 319040c55..13fb6966c 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -86,6 +86,8 @@ def condition_func(reference_params, test_param): bilby.core.prior.Lorentzian(name="test", unit="unit", alpha=0, beta=1), bilby.core.prior.Gamma(name="test", unit="unit", k=1, theta=1), bilby.core.prior.ChiSquared(name="test", unit="unit", nu=2), + bilby.core.prior.FermiDirac(name="test", unit="unit", mu=1, sigma=1), + bilby.core.prior.SymmetricLogUniform(name="test", unit="unit", minimum=1e-2, maximum=1e2), bilby.gw.prior.AlignedSpin(name="test", unit="unit"), bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit"), bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit"), @@ -228,6 +230,9 @@ def tearDown(self): def test_minimum_rescaling(self): """Test the the rescaling works as expected.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue if bilby.core.prior.JointPrior in prior.__class__.__mro__: minimum_sample = prior.rescale(0) if prior.dist.filled_rescale(): @@ -252,6 +257,9 @@ def test_maximum_rescaling(self): def test_many_sample_rescaling(self): """Test the the rescaling works as expected.""" for prior in self.priors: + if isinstance(prior, 
bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue many_samples = prior.rescale(np.random.uniform(0, 1, 1000)) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_rescale(): @@ -270,6 +278,9 @@ def test_least_recently_sampled(self): def test_sampling_single(self): """Test that sampling from the prior always returns values within its domain.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue single_sample = prior.sample() self.assertTrue( (single_sample >= prior.minimum) & (single_sample <= prior.maximum) @@ -278,6 +289,9 @@ def test_sampling_single(self): def test_sampling_many(self): """Test that sampling from the prior always returns values within its domain.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue many_samples = prior.sample(5000) self.assertTrue( (all(many_samples >= prior.minimum)) @@ -300,6 +314,9 @@ def test_probability_above_domain(self): def test_probability_below_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue if prior.minimum != -np.inf: outside_domain = np.linspace( prior.minimum - 1e4, prior.minimum - 1, 1000 @@ -358,6 +375,9 @@ def test_cdf_one_above_domain(self): def test_cdf_zero_below_domain(self): for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue if ( bilby.core.prior.JointPrior in prior.__class__.__mro__ and prior.maximum == np.inf @@ -527,6 +547,9 @@ def 
test_probability_surrounding_domain(self): # skip delta function prior in this case if isinstance(prior, bilby.core.prior.DeltaFunction): continue + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000) indomain = (surround_domain >= prior.minimum) | ( surround_domain <= prior.maximum @@ -541,11 +564,18 @@ def test_probability_surrounding_domain(self): self.assertTrue(all(prior.prob(surround_domain[outdomain]) == 0)) def test_normalized(self): - """Test that each of the priors are normalised, this needs care for delta function and Gaussian priors""" + """ + Test that each of the priors are normalised. + This needs extra care for priors defined on infinite domains and the + Cauchy, DeltaFunction, and SymmetricLogUniform priors are skipped + because they are too sharply peaked to be tested efficiently in this way. + """ for prior in self.priors: - if isinstance(prior, bilby.core.prior.DeltaFunction): - continue - if isinstance(prior, bilby.core.prior.Cauchy): + if isinstance(prior, ( + bilby.core.prior.DeltaFunction, + bilby.core.prior.Cauchy, + bilby.core.prior.SymmetricLogUniform + )): continue if bilby.core.prior.JointPrior in prior.__class__.__mro__: continue @@ -757,6 +787,7 @@ def test_set_minimum_setting(self): bilby.core.prior.MultivariateGaussian, bilby.core.prior.FermiDirac, bilby.core.prior.Triangular, + bilby.core.prior.SymmetricLogUniform, bilby.gw.prior.HealPixPrior, ), ): From 980e0cb01e2cb4d50818b853c2d9a555ec4a158d Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 12 Jul 2024 16:38:40 -0500 Subject: [PATCH 04/13] FMT: remove extraneous whitespace --- bilby/core/prior/analytical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 8b03960ce..cb3c38462 100644 --- a/bilby/core/prior/analytical.py +++ 
b/bilby/core/prior/analytical.py @@ -1423,7 +1423,7 @@ def cdf(self, val): ========== val: Union[float, int, array_like] The value(s) to evaluate the CDF at - + Returns ======= Union[float, array_like]: From e93329935050a02623947d3b21e4f90fcd64548e Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 23 Jan 2025 10:05:41 -0600 Subject: [PATCH 05/13] BUG: revert bad changes to equal comparisons --- bilby/core/prior/analytical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 6b26badf5..0775a3f4d 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -375,7 +375,7 @@ def cdf(self, val): * (val <= -self.minimum) * (val >= -self.maximum) + (0.5 + norm * np.log(abs(val) / self.minimum)) * (val >= self.minimum) * (val <= self.maximum) - + 0.5 * (val >= -self.minimum) * (val <= self.minimum) + + 0.5 * (val > -self.minimum) * (val < self.minimum) + 1 * (val > self.maximum) ) return _cdf From 4b24289b953f61120c243c58872a8b0e3404ff6a Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 23 Jan 2025 10:33:40 -0600 Subject: [PATCH 06/13] BUG: Fix syntax error in return statement --- bilby/core/prior/dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 50135099b..893a83e9f 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -866,7 +866,7 @@ def rescale(self, keys, theta): values = np.concatenate([values, result[key]]) for key, value in zip(names, values): result[key] = value - return [list(np.asarray(result[key]).flatten()) for key in keys]) + return [list(np.asarray(result[key]).flatten()) for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From ebf93d86663e883ce8c040a006511ee13e7f3f6b Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 23 Jan 2025 11:11:30 -0600 Subject: [PATCH 07/13] BUG: Fix list return 
--- bilby/core/prior/dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 893a83e9f..5631f9470 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -866,7 +866,7 @@ def rescale(self, keys, theta): values = np.concatenate([values, result[key]]) for key, value in zip(names, values): result[key] = value - return [list(np.asarray(result[key]).flatten()) for key in keys] + return [np.asarray(result[key]).flatten() for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From 388e3fff19b6f67e231efe14713ca46ac3889ace Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 23 Jan 2025 11:38:59 -0600 Subject: [PATCH 08/13] BUG: fix array type output for rescale --- bilby/core/prior/dict.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 5631f9470..f798cbe1f 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -866,7 +866,10 @@ def rescale(self, keys, theta): values = np.concatenate([values, result[key]]) for key, value in zip(names, values): result[key] = value - return [np.asarray(result[key]).flatten() for key in keys] + # this is gross but can be removed whenever we switch to returning + # arrays, flatten converts 0-d arrays to 1-d and squeeze converts it + # back + return [np.asarray(result[key]).flatten().squeeze for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From 81951d22040624c7873b6be8d99c5be46c24a59f Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 24 Jan 2025 02:14:14 -0600 Subject: [PATCH 09/13] BUG: Fix missing parentheses in squeeze method call --- bilby/core/prior/dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index f798cbe1f..e7bf3bdf0 100644 --- a/bilby/core/prior/dict.py +++ 
b/bilby/core/prior/dict.py @@ -869,7 +869,7 @@ def rescale(self, keys, theta): # this is gross but can be removed whenever we switch to returning # arrays, flatten converts 0-d arrays to 1-d and squeeze converts it # back - return [np.asarray(result[key]).flatten().squeeze for key in keys] + return [np.asarray(result[key]).flatten().squeeze() for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From 433471866267c1c313b01c3eed610f70fb837fc5 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 4 Feb 2025 15:16:43 +0000 Subject: [PATCH 10/13] BUG: fix flattening logic --- bilby/core/prior/dict.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index e7bf3bdf0..28876330d 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -866,10 +866,19 @@ def rescale(self, keys, theta): values = np.concatenate([values, result[key]]) for key, value in zip(names, values): result[key] = value - # this is gross but can be removed whenever we switch to returning - # arrays, flatten converts 0-d arrays to 1-d and squeeze converts it - # back - return [np.asarray(result[key]).flatten().squeeze() for key in keys] + + def safe_flatten(value): + """ + this is gross but can be removed whenever we switch to returning + arrays, flatten converts 0-d arrays to 1-d so has to be special + cased + """ + if isinstance(value, (float, int)): + return value + else: + return value.flatten() + + return [safe_flatten(result[key]) for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From 2a1809ca2d828eb49480adfb01bd87677654a747 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 4 Feb 2025 15:17:09 +0000 Subject: [PATCH 11/13] BUG: stop using arr[np.where(cond)] --- bilby/core/prior/slabspike.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git
a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py index 92664b15e..6910be608 100644 --- a/bilby/core/prior/slabspike.py +++ b/bilby/core/prior/slabspike.py @@ -88,11 +88,11 @@ def rescale(self, val): original_is_number = isinstance(val, Number) val = np.atleast_1d(val) - lower_indices = np.where(val < self.inverse_cdf_below_spike)[0] - intermediate_indices = np.where(np.logical_and( + lower_indices = val < self.inverse_cdf_below_spike + intermediate_indices = np.logical_and( self.inverse_cdf_below_spike <= val, - val <= self.inverse_cdf_below_spike + self.spike_height))[0] - higher_indices = np.where(val > self.inverse_cdf_below_spike + self.spike_height)[0] + val <= (self.inverse_cdf_below_spike + self.spike_height)) + higher_indices = val > (self.inverse_cdf_below_spike + self.spike_height) res = np.zeros(len(val)) res[lower_indices] = self._contracted_rescale(val[lower_indices]) @@ -137,7 +137,7 @@ def prob(self, val): original_is_number = isinstance(val, Number) res = self.slab.prob(val) * self.slab_fraction res = np.atleast_1d(res) - res[np.where(val == self.spike_location)] = np.inf + res[val == self.spike_location] = np.inf if original_is_number: try: res = res[0] @@ -161,7 +161,7 @@ def ln_prob(self, val): original_is_number = isinstance(val, Number) res = self.slab.ln_prob(val) + np.log(self.slab_fraction) res = np.atleast_1d(res) - res[np.where(val == self.spike_location)] = np.inf + res[val == self.spike_location] = np.inf if original_is_number: try: res = res[0] @@ -185,7 +185,5 @@ def cdf(self, val): """ res = self.slab.cdf(val) * self.slab_fraction - res = np.atleast_1d(res) - indices_above_spike = np.where(val > self.spike_location)[0] - res[indices_above_spike] += self.spike_height + res += self.spike_height * (val > self.spike_location) return res From e38b6de62dc2814ca5ee23cbaa74c8022bc18657 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 19 Feb 2025 13:39:37 -0600 Subject: [PATCH 12/13] Remove unused import interp1d --- 
bilby/gw/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index e07f822d7..efbfded85 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -3,7 +3,7 @@ import numpy as np from scipy.integrate import quad -from scipy.interpolate import InterpolatedUnivariateSpline, interp1d +from scipy.interpolate import InterpolatedUnivariateSpline from scipy.special import hyp2f1 from scipy.stats import norm From 7db3439c34c672d81737edfa9ff077443e035ec1 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 21 Aug 2025 08:56:20 -0500 Subject: [PATCH 13/13] Address comments --- bilby/core/prior/dict.py | 6 ++++++ test/core/prior/prior_test.py | 5 +---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 4b3a3af7c..d037dc985 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -880,6 +880,12 @@ def rescale(self, keys, theta): elif isinstance(self[key], JointPrior): joint[self[key].dist.distname].append(key) for names in joint.values(): + # this is needed to unpack how joint prior rescaling works + # as an example of a joint prior over {a, b, c, d} we might + # get the following based on the order within the joint prior + # {a: [], b: [], c: [1, 2, 3, 4], d: []} + # -> [1, 2, 3, 4] + # -> {a: 1, b: 2, c: 3, d: 4} values = list() for key in names: values = np.concatenate([values, result[key]]) diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 0e81c6a06..333defdb6 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -402,10 +402,7 @@ def test_cdf_zero_below_domain(self): def test_cdf_float_with_float_input(self): for prior in self.priors: - if ( - bilby.core.prior.JointPrior in prior.__class__.__mro__ - and prior.maximum == np.inf - ): + if bilby.core.prior.JointPrior in prior.__class__.__mro__: continue self.assertIsInstance(prior.cdf(prior.sample()), float)