diff --git a/gwpopulation/hyperpe.py b/gwpopulation/hyperpe.py index 3cb54395..d2ce2c40 100644 --- a/gwpopulation/hyperpe.py +++ b/gwpopulation/hyperpe.py @@ -48,7 +48,7 @@ def __init__(self, posteriors, hyper_prior, sampling_prior=None, Parameters ---------- posteriors: list - An list of pandas data frames of samples sets of samples. + A list of pandas data frames of samples sets of samples. Each set may have a different size. These can contain a `prior` column containing the original prior values. @@ -111,7 +111,7 @@ def log_likelihood_ratio(self): self.parameters, added_keys = self.conversion_function(self.parameters) self.hyper_prior.parameters.update(self.parameters) ln_l = xp.sum(self._compute_per_event_ln_bayes_factors()) - ln_l += self._get_selection_factor() + ln_l += self.n_posteriors * self._get_selection_factor() if added_keys is not None: for key in added_keys: self.parameters.pop(key) @@ -129,7 +129,7 @@ def _compute_per_event_ln_bayes_factors(self): self.sampling_prior, axis=-1)) def _get_selection_factor(self): - return - self.n_posteriors * xp.log( + return - xp.log( self.selection_function(self.parameters)) def generate_extra_statistics(self, sample): @@ -274,3 +274,66 @@ def _get_selection_factor(self): def generate_rate_posterior_sample(self): pass + + +class PastroLikelihood(HyperparameterLikelihood): + """ + A mixture model likelihood which utlises p_astro for the effective + noise likelihood. + This likelihood is marginalised over rate. + + See Eq. (40) of https://arxiv.org/abs/1912.09708 for a definition. + """ + def __init__(self, posteriors, hyper_prior, sampling_prior=None, + ln_evidences=None, max_samples=1e100, pastro=None, + selection_function=lambda args: 1, + conversion_function=lambda args: (args, None), + fiducial_selection=None, cupy=True): + """ + Parameters + ---------- + pastro: array + An array of pastro values corresponding to the event posteriors. + fiducial_selection: float + The visible spacetime volume described by the fiducial model + determined by the sampling prior distribution. + """ + if pastro is not None: + if fiducial_selection is not None: + logger.info('Fiducial VT set to {}'.format(fiducial_selection)) + self.fiducial_selection = fiducial_selection + else: + logger.warning('No fiducial VT provided, defaulting to 1.0') + self.fiducial_selection = 1.0 + if len(pastro) != len(posteriors): + raise ValueError('Number of pastro values provided are not ' + 'equal to the number of posteriors.') + self.pastro = pastro + + super().__init__(self) + + def log_likelihood_ratio(self): + self.parameters, added_keys = self.conversion_function(self.parameters) + self.hyper_prior.parameters.update(self.parameters) + ln_l1 = self._compute_per_event_ln_bayes_factors() + ln_l1 += self._get_selection_factor() + ln_l1 += self._get_fiducial_selection_factor() + + if added_keys is not None: + for key in added_keys: + self.parameters.pop(key) + + ln_l2 = self._get_pastro_factor + ln_l2 = xp.nan_to_num(ln_l2) + + ln_l = xp.sum(xp.logaddexp(ln_l1, ln_l2)) + ln_l += 0.5 * self._get_selection_factor() + + return float(xp.nan_to_num(ln_l)) + + def _get_pastro_factor(self): + pastro_factor = (1.0 - self.pastro)/self.pastro + return xp.log(pastro_factor) + + def _get_fiducial_selection_factor(self): + return xp.log(self.fiducial_selection)