diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 547339dcb..0775a3f4d 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -169,8 +169,10 @@ def cdf(self, val): _cdf = (np.log(val / self.minimum) / np.log(self.maximum / self.minimum)) else: - _cdf = np.atleast_1d(val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) / \ - (self.maximum ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) + _cdf = ( + (val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) + / (self.maximum ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) + ) _cdf = np.minimum(_cdf, 1) _cdf = np.maximum(_cdf, 0) return _cdf @@ -367,16 +369,16 @@ def ln_prob(self, val): return np.nan_to_num(- np.log(2 * np.abs(val)) - np.log(np.log(self.maximum / self.minimum))) def cdf(self, val): - val = np.atleast_1d(val) norm = 0.5 / np.log(self.maximum / self.minimum) - cdf = np.zeros((len(val))) - lower_indices = np.where(np.logical_and(-self.maximum <= val, val <= -self.minimum))[0] - upper_indices = np.where(np.logical_and(self.minimum <= val, val <= self.maximum))[0] - cdf[lower_indices] = -norm * np.log(-val[lower_indices] / self.maximum) - cdf[np.where(np.logical_and(-self.minimum < val, val < self.minimum))] = 0.5 - cdf[upper_indices] = 0.5 + norm * np.log(val[upper_indices] / self.minimum) - cdf[np.where(self.maximum < val)] = 1 - return cdf + _cdf = ( + -norm * np.log(abs(val) / self.maximum) + * (val <= -self.minimum) * (val >= -self.maximum) + + (0.5 + norm * np.log(abs(val) / self.minimum)) + * (val >= self.minimum) * (val <= self.maximum) + + 0.5 * (val > -self.minimum) * (val < self.minimum) + + 1 * (val > self.maximum) + ) + return _cdf class Cosine(Prior): @@ -426,10 +428,12 @@ def prob(self, val): return np.cos(val) / 2 * self.is_in_prior_range(val) def cdf(self, val): - _cdf = np.atleast_1d((np.sin(val) - np.sin(self.minimum)) / - (np.sin(self.maximum) - np.sin(self.minimum))) - _cdf[val > self.maximum] = 1 - 
_cdf[val < self.minimum] = 0 + _cdf = ( + (np.sin(val) - np.sin(self.minimum)) + / (np.sin(self.maximum) - np.sin(self.minimum)) + * (val >= self.minimum) * (val <= self.maximum) + + 1 * (val > self.maximum) + ) return _cdf @@ -480,10 +484,12 @@ def prob(self, val): return np.sin(val) / 2 * self.is_in_prior_range(val) def cdf(self, val): - _cdf = np.atleast_1d((np.cos(val) - np.cos(self.minimum)) / - (np.cos(self.maximum) - np.cos(self.minimum))) - _cdf[val > self.maximum] = 1 - _cdf[val < self.minimum] = 0 + _cdf = ( + (np.cos(val) - np.cos(self.minimum)) + / (np.cos(self.maximum) - np.cos(self.minimum)) + * (val >= self.minimum) * (val <= self.maximum) + + 1 * (val > self.maximum) + ) return _cdf @@ -625,11 +631,13 @@ def prob(self, val): / self.sigma / self.normalisation * self.is_in_prior_range(val) def cdf(self, val): - val = np.atleast_1d(val) - _cdf = (erf((val - self.mu) / 2 ** 0.5 / self.sigma) - erf( - (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 / self.normalisation - _cdf[val > self.maximum] = 1 - _cdf[val < self.minimum] = 0 + _cdf = ( + ( + erf((val - self.mu) / 2 ** 0.5 / self.sigma) + - erf((self.minimum - self.mu) / 2 ** 0.5 / self.sigma) + ) / 2 / self.normalisation * (val >= self.minimum) * (val <= self.maximum) + + 1 * (val > self.maximum) + ) return _cdf @@ -1367,6 +1375,8 @@ def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None, raise ValueError("For the Fermi-Dirac prior the values of sigma and r " "must be positive.") + self.expr = np.exp(self.r) + def rescale(self, val): """ 'Rescale' a sample from the unit line element to the appropriate Fermi-Dirac prior. @@ -1384,21 +1394,8 @@ def rescale(self, val): .. [1] M. Pitkin, M. Isi, J. Veitch & G. Woan, `arXiv:1705.08978v1 `_, 2017. """ - inv = (-np.exp(-1. * self.r) + (1. + np.exp(self.r)) ** -val + - np.exp(-1. * self.r) * (1. 
+ np.exp(self.r)) ** -val) - - # if val is 1 this will cause inv to be negative (due to numerical - # issues), so return np.inf - if isinstance(val, (float, int)): - if inv < 0: - return np.inf - else: - return -self.sigma * np.log(inv) - else: - idx = inv >= 0. - tmpinv = np.inf * np.ones(len(np.atleast_1d(val))) - tmpinv[idx] = -self.sigma * np.log(inv[idx]) - return tmpinv + inv = -1 / self.expr + (1 + self.expr)**-val + (1 + self.expr)**-val / self.expr + return -self.sigma * np.log(np.maximum(inv, 0)) def prob(self, val): """Return the prior probability of val. @@ -1411,7 +1408,11 @@ def prob(self, val): ======= float: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return ( + (np.exp((val - self.mu) / self.sigma) + 1)**-1 + / (self.sigma * np.log1p(self.expr)) + * (val >= self.minimum) + ) def ln_prob(self, val): """Return the log prior probability of val. @@ -1424,19 +1425,34 @@ def ln_prob(self, val): ======= Union[float, array_like]: Log prior probability of val """ + return np.log(self.prob(val)) - norm = -np.log(self.sigma * np.log(1. + np.exp(self.r))) - if isinstance(val, (float, int)): - if val < self.minimum: - return -np.inf - else: - return norm - np.logaddexp((val / self.sigma) - self.r, 0.) - else: - val = np.atleast_1d(val) - lnp = -np.inf * np.ones(len(val)) - idx = val >= self.minimum - lnp[idx] = norm - np.logaddexp((val[idx] / self.sigma) - self.r, 0.) - return lnp + def cdf(self, val): + """ + Evaluate the CDF of the Fermi-Dirac distribution using a slightly + modified form of Equation 23 of [1]_. + + Parameters + ========== + val: Union[float, int, array_like] + The value(s) to evaluate the CDF at + + Returns + ======= + Union[float, array_like]: + The CDF value(s) + + References + ========== + + .. [1] M. Pitkin, M. Isi, J. Veitch & G. Woan, `arXiv:1705.08978v1 + `_, 2017. 
+ """ + result = ( + (np.logaddexp(0, -self.r) - np.logaddexp(-val / self.sigma, -self.r)) + / np.logaddexp(0, self.r) + ) + return np.clip(result, 0, 1) class Categorical(Prior): diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index 3ea4e19b9..0be3999a4 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -5,7 +5,6 @@ import numpy as np import scipy.stats -from scipy.interpolate import interp1d from ..utils import ( infer_args_from_method, @@ -13,6 +12,7 @@ decode_bilby_json, logger, get_dict_with_properties, + WrappedInterp1d as interp1d, ) @@ -178,7 +178,10 @@ def cdf(self, val): cdf = cumulative_trapezoid(pdf, x, initial=0) interp = interp1d(x, cdf, assume_sorted=True, bounds_error=False, fill_value=(0, 1)) - return interp(val) + output = interp(val) + if isinstance(val, (int, float)): + output = float(output) + return output def ln_prob(self, val): """Return the prior ln probability of val, this should be overwritten diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 4bb514c34..d037dc985 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -487,7 +487,9 @@ def check_efficiency(n_tested, n_valid): def normalize_constraint_factor( self, keys, min_accept=10000, sampling_chunk=50000, nrepeats=10 ): - if keys in self._cached_normalizations.keys(): + if len(self.constraint_keys) == 0: + return 1 + elif keys in self._cached_normalizations.keys(): return self._cached_normalizations[keys] else: factor_estimates = [ @@ -549,8 +551,10 @@ def check_prob(self, sample, prob): return 0.0 else: constrained_prob = np.zeros_like(prob) - keep = np.array(self.evaluate_constraints(sample), dtype=bool) - constrained_prob[keep] = prob[keep] * ratio + in_bounds = np.isfinite(prob) + subsample = {key: sample[key][in_bounds] for key in sample} + keep = np.array(self.evaluate_constraints(subsample), dtype=bool) + constrained_prob[in_bounds] = prob[in_bounds] * keep * ratio return constrained_prob def 
ln_prob(self, sample, axis=None, normalized=True): @@ -591,8 +595,10 @@ def check_ln_prob(self, sample, ln_prob, normalized=True): return -np.inf else: constrained_ln_prob = -np.inf * np.ones_like(ln_prob) - keep = np.array(self.evaluate_constraints(sample), dtype=bool) - constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio) + in_bounds = np.isfinite(ln_prob) + subsample = {key: sample[key][in_bounds] for key in sample} + keep = np.log(np.array(self.evaluate_constraints(subsample), dtype=bool)) + constrained_ln_prob[in_bounds] = ln_prob[in_bounds] + keep + np.log(ratio) return constrained_ln_prob def cdf(self, sample): @@ -653,9 +659,7 @@ def test_has_redundant_keys(self): del temp[key] if temp.test_redundancy(key, disable_logging=True): logger.warning( - "{} is a redundant key in this {}.".format( - key, self.__class__.__name__ - ) + f"{key} is a redundant key in this {self.__class__.__name__}." ) redundant = True return redundant @@ -863,6 +867,7 @@ def rescale(self, keys, theta): self._check_resolved() self._update_rescale_keys(keys) result = dict() + joint = dict() for key, index in zip( self.sorted_keys_without_fixed_parameters, self._rescale_indexes ): @@ -870,10 +875,35 @@ def rescale(self, keys, theta): theta[index], **self.get_required_variables(key) ) self[key].least_recently_sampled = result[key] - samples = [] - for key in keys: - samples += list(np.asarray(result[key]).flatten()) - return samples + if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint: + joint[self[key].dist.distname] = [key] + elif isinstance(self[key], JointPrior): + joint[self[key].dist.distname].append(key) + for names in joint.values(): + # this is needed to unpack how joint prior rescaling works + # as an example of a joint prior over {a, b, c, d} we might + # get the following based on the order within the joint prior + # {a: [], b: [], c: [1, 2, 3, 4], d: []} + # -> [1, 2, 3, 4] + # -> {a: 1, b: 2, c: 3, d: 4} + values = list() + for key in names: + 
values = np.concatenate([values, result[key]]) + for key, value in zip(names, values): + result[key] = value + + def safe_flatten(value): + """ + this is gross but can be removed whenever we switch to returning + arrays, flatten converts 0-d arrays to 1-d so has to be special + cased + """ + if isinstance(value, (float, int)): + return value + else: + return value.flatten() + + return [safe_flatten(result[key]) for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: diff --git a/bilby/core/prior/interpolated.py b/bilby/core/prior/interpolated.py index 2cee669d9..6a7b383a5 100644 --- a/bilby/core/prior/interpolated.py +++ b/bilby/core/prior/interpolated.py @@ -1,8 +1,7 @@ import numpy as np -from scipy.interpolate import interp1d from .base import Prior -from ..utils import logger +from ..utils import logger, WrappedInterp1d as interp1d class Interped(Prior): @@ -86,10 +85,7 @@ def rescale(self, val): This maps to the inverse CDF. This is done using interpolation. 
""" - rescaled = self.inverse_cumulative_distribution(val) - if rescaled.shape == (): - rescaled = float(rescaled) - return rescaled + return self.inverse_cumulative_distribution(val) @property def minimum(self): diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py index 92664b15e..6910be608 100644 --- a/bilby/core/prior/slabspike.py +++ b/bilby/core/prior/slabspike.py @@ -88,11 +88,11 @@ def rescale(self, val): original_is_number = isinstance(val, Number) val = np.atleast_1d(val) - lower_indices = np.where(val < self.inverse_cdf_below_spike)[0] - intermediate_indices = np.where(np.logical_and( + lower_indices = val < self.inverse_cdf_below_spike + intermediate_indices = np.logical_and( self.inverse_cdf_below_spike <= val, - val <= self.inverse_cdf_below_spike + self.spike_height))[0] - higher_indices = np.where(val > self.inverse_cdf_below_spike + self.spike_height)[0] + val <= (self.inverse_cdf_below_spike + self.spike_height)) + higher_indices = val > (self.inverse_cdf_below_spike + self.spike_height) res = np.zeros(len(val)) res[lower_indices] = self._contracted_rescale(val[lower_indices]) @@ -137,7 +137,7 @@ def prob(self, val): original_is_number = isinstance(val, Number) res = self.slab.prob(val) * self.slab_fraction res = np.atleast_1d(res) - res[np.where(val == self.spike_location)] = np.inf + res[val == self.spike_location] = np.inf if original_is_number: try: res = res[0] @@ -161,7 +161,7 @@ def ln_prob(self, val): original_is_number = isinstance(val, Number) res = self.slab.ln_prob(val) + np.log(self.slab_fraction) res = np.atleast_1d(res) - res[np.where(val == self.spike_location)] = np.inf + res[val == self.spike_location] = np.inf if original_is_number: try: res = res[0] @@ -185,7 +185,5 @@ def cdf(self, val): """ res = self.slab.cdf(val) * self.slab_fraction - res = np.atleast_1d(res) - indices_above_spike = np.where(val > self.spike_location)[0] - res[indices_above_spike] += self.spike_height + res += self.spike_height * (val 
> self.spike_location) return res diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index ac6fcefcd..e47a51d3e 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -1,7 +1,7 @@ import math import numpy as np -from scipy.interpolate import RectBivariateSpline +from scipy.interpolate import RectBivariateSpline, interp1d from scipy.special import logsumexp from .log import logger @@ -219,6 +219,29 @@ def __call__(self, x, y, dx=0, dy=0, grid=False): return result +class WrappedInterp1d(interp1d): + """ + A wrapper around scipy interp1d which sets equality-by-instantiation and + makes sure that the output is a float if the input is a float or int. + """ + def __call__(self, x): + output = super().__call__(x) + if isinstance(x, (float, int)): + output = output.item() + return output + + def __eq__(self, other): + for key in self.__dict__: + if type(self.__dict__[key]) is np.ndarray: + if not np.array_equal(self.__dict__[key], other.__dict__[key]): + return False + elif key == "_spline": + pass + elif getattr(self, key) != getattr(other, key): + return False + return True + + def round_up_to_power_of_two(x): """Round up to the next power of two diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 6885af67b..08d7178e0 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -3,7 +3,7 @@ import numpy as np from scipy.integrate import quad -from scipy.interpolate import InterpolatedUnivariateSpline, interp1d +from scipy.interpolate import InterpolatedUnivariateSpline from scipy.special import hyp2f1 from scipy.stats import norm @@ -13,7 +13,7 @@ ConditionalPriorDict, ConditionalBasePrior, BaseJointPriorDist, JointPrior, JointPriorDistError, ) -from ..core.utils import infer_args_from_method, logger, random +from ..core.utils import infer_args_from_method, logger, random, WrappedInterp1d as interp1d from .conversion import ( convert_to_lal_binary_black_hole_parameters, convert_to_lal_binary_neutron_star_parameters, 
generate_mass_parameters, @@ -379,21 +379,6 @@ def __init__(self, minimum, maximum, name='chirp_mass', name=name, latex_label=latex_label, unit=unit, boundary=boundary) -class WrappedInterp1d(interp1d): - """ A wrapper around scipy interp1d which sets equality-by-instantiation """ - def __eq__(self, other): - - for key in self.__dict__: - if type(self.__dict__[key]) is np.ndarray: - if not np.array_equal(self.__dict__[key], other.__dict__[key]): - return False - elif key == "_spline": - pass - elif getattr(self, key) != getattr(other, key): - return False - return True - - class UniformInComponentsMassRatio(Prior): r""" Prior distribution for chirp mass which is uniform in component masses. @@ -437,7 +422,7 @@ def __init__(self, minimum, maximum, name='mass_ratio', latex_label='$q$', latex_label=latex_label, unit=unit, boundary=boundary) self.norm = self._integral(maximum) - self._integral(minimum) qs = np.linspace(minimum, maximum, 1000) - self.icdf = WrappedInterp1d( + self.icdf = interp1d( self.cdf(qs), qs, kind='cubic', bounds_error=False, fill_value=(minimum, maximum)) @@ -451,11 +436,7 @@ def cdf(self, val): def rescale(self, val): if self.equal_mass: val = 2 * np.minimum(val, 1 - val) - resc = self.icdf(val) - if resc.ndim == 0: - return resc.item() - else: - return resc + return self.icdf(val) def prob(self, val): in_prior = (val >= self.minimum) & (val <= self.maximum) diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 37c6d93f3..333defdb6 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -86,6 +86,8 @@ def condition_func(reference_params, test_param): bilby.core.prior.Lorentzian(name="test", unit="unit", alpha=0, beta=1), bilby.core.prior.Gamma(name="test", unit="unit", k=1, theta=1), bilby.core.prior.ChiSquared(name="test", unit="unit", nu=2), + bilby.core.prior.FermiDirac(name="test", unit="unit", mu=1, sigma=1), + bilby.core.prior.SymmetricLogUniform(name="test", unit="unit", minimum=1e-2, 
maximum=1e2), bilby.gw.prior.AlignedSpin(name="test", unit="unit"), bilby.gw.prior.AlignedSpin( a_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0), @@ -235,7 +237,10 @@ def tearDown(self): def test_minimum_rescaling(self): """Test the the rescaling works as expected.""" for prior in self.priors: - if isinstance(prior, bilby.gw.prior.AlignedSpin): + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue + elif isinstance(prior, bilby.gw.prior.AlignedSpin): # the edge of the prior is extremely suppressed for these priors # and so the rescale function doesn't quite return the lower bound continue @@ -263,6 +268,9 @@ def test_maximum_rescaling(self): def test_many_sample_rescaling(self): """Test the the rescaling works as expected.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue many_samples = prior.rescale(np.random.uniform(0, 1, 1000)) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_rescale(): @@ -281,6 +289,9 @@ def test_least_recently_sampled(self): def test_sampling_single(self): """Test that sampling from the prior always returns values within its domain.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue single_sample = prior.sample() self.assertTrue( (single_sample >= prior.minimum) & (single_sample <= prior.maximum) @@ -289,6 +300,9 @@ def test_sampling_single(self): def test_sampling_many(self): """Test that sampling from the prior always returns values within its domain.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue many_samples = prior.sample(5000) self.assertTrue( (all(many_samples >= 
prior.minimum)) @@ -311,6 +325,9 @@ def test_probability_above_domain(self): def test_probability_below_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue if prior.minimum != -np.inf: outside_domain = np.linspace( prior.minimum - 1e4, prior.minimum - 1, 1000 @@ -369,6 +386,9 @@ def test_cdf_one_above_domain(self): def test_cdf_zero_below_domain(self): for prior in self.priors: + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue if ( bilby.core.prior.JointPrior in prior.__class__.__mro__ and prior.maximum == np.inf @@ -380,6 +400,12 @@ def test_cdf_zero_below_domain(self): ) self.assertTrue(all(np.nan_to_num(prior.cdf(outside_domain)) == 0)) + def test_cdf_float_with_float_input(self): + for prior in self.priors: + if bilby.core.prior.JointPrior in prior.__class__.__mro__: + continue + self.assertIsInstance(prior.cdf(prior.sample()), float) + def test_log_normal_fail(self): with self.assertRaises(ValueError): bilby.core.prior.LogNormal(name="test", unit="unit", mu=0, sigma=-1) @@ -529,6 +555,9 @@ def test_probability_surrounding_domain(self): # skip delta function prior in this case if isinstance(prior, bilby.core.prior.DeltaFunction): continue + if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): + # SymmetricLogUniform has support down to -maximum + continue surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000) indomain = (surround_domain >= prior.minimum) | ( surround_domain <= prior.maximum @@ -543,11 +572,18 @@ def test_probability_surrounding_domain(self): self.assertTrue(all(prior.prob(surround_domain[outdomain]) == 0)) def test_normalized(self): - """Test that each of the priors are normalised, this needs care for delta 
function and Gaussian priors""" + """ + Test that each of the priors are normalised. + This needs extra care for priors defined on infinite domains and the + Cauchy, DeltaFunction, and SymmetricLogUniform priors are skipped + because they are too sharply peaked to be tested efficiently in this way. + """ for prior in self.priors: - if isinstance(prior, bilby.core.prior.DeltaFunction): - continue - if isinstance(prior, bilby.core.prior.Cauchy): + if isinstance(prior, ( + bilby.core.prior.DeltaFunction, + bilby.core.prior.Cauchy, + bilby.core.prior.SymmetricLogUniform + )): continue if bilby.core.prior.JointPrior in prior.__class__.__mro__: continue @@ -761,6 +797,7 @@ def test_set_minimum_setting(self): bilby.core.prior.MultivariateGaussian, bilby.core.prior.FermiDirac, bilby.core.prior.Triangular, + bilby.core.prior.SymmetricLogUniform, bilby.gw.prior.HealPixPrior, ), ):