Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge

dependencies:
- python >= 3.11
- python >= 3.11,<3.14
- mpich
- lhapdf
- pandoc
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge

dependencies:
- python >= 3.11
- python >= 3.11,<3.14
- mpich
- lhapdf
- pandoc
Expand Down
1 change: 0 additions & 1 deletion grid_pdf/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from colibri.app import colibriApp
from grid_pdf.config import GridPdfConfig


grid_pdf_providers = [
"grid_pdf.model",
"grid_pdf.utils",
Expand Down
22 changes: 21 additions & 1 deletion grid_pdf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from validphys.core import PDF

import colibri.bayes_prior
from colibri.core import BayesianPrior


def closure_test_central_pdf_grid(
Expand Down Expand Up @@ -138,6 +139,14 @@ def bayesian_prior(prior_settings, pdf_model):
error_up = mean + delta * nsigma
error_down = mean - delta * nsigma

# Define dummy log_prob and sample for now
@jax.jit
def log_prob(x):
raise NotImplementedError("log_prob not implemented for Gaussian prior")

def sample(rng_key, n_samples):
raise NotImplementedError("sample not implemented for Gaussian prior")

@jax.jit
def prior_transform(cube):
params = error_down + (error_up - error_down) * cube
Expand All @@ -163,6 +172,13 @@ def prior_transform(cube):
)
cholesky_pdf_covmat = jnp.diag(jnp.sqrt(pdf_diag_covmat_prior))

@jax.jit
def log_prob(x):
raise NotImplementedError("log_prob not implemented for Gaussian prior")

def sample(rng_key, n_samples):
raise NotImplementedError("sample not implemented for Gaussian prior")

@jax.jit
def prior_transform(cube):
"""
Expand All @@ -181,7 +197,11 @@ def prior_transform(cube):
else:
return colibri.bayes_prior.bayesian_prior(prior_settings)

return prior_transform
return BayesianPrior(
prior_transform=prior_transform,
log_prob=log_prob,
sample=sample,
)


def pdf_initial_parameters(pdf_model, param_initialiser_settings, replica_index):
Expand Down
Loading