Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions .github/workflows/latest-dynesty.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Test against latest dynesty

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
merge_group:
release:
types:
- published

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

env:
CONDA_PATH: /opt/conda/

jobs:
build:

name: ${{ matrix.python.name }} unit tests
runs-on: ubuntu-latest
container: ghcr.io/bilby-dev/bilby-python${{ matrix.python.short-version }}:latest
strategy:
fail-fast: false
matrix:
python:
- name: Python 3.12
version: 3.12
short-version: 312

steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
fetch-tags: true
- name: Install package
run: |
# activate env so that conda list shows the correct environment
source $CONDA_PATH/bin/activate python${{ matrix.python.short-version }}
python -m pip install .
python -m pip install git+https://github.com/joshspeagle/dynesty@master
conda list --show-channel-urls
shell: bash
- name: Run unit tests
run: |
python -m pytest --durations 10 -k dynesty
- name: Run sampler tests
run: |
python -m pytest test/integration/sampler_run_test.py --durations 10 -v -k dynesty
2 changes: 1 addition & 1 deletion bilby/core/sampler/dynamic_dynesty.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def nlive(self):

@property
def sampler_init(self):
from dynesty import DynamicNestedSampler
from dynesty.dynesty import DynamicNestedSampler

return DynamicNestedSampler

Expand Down
52 changes: 46 additions & 6 deletions bilby/core/sampler/dynesty.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ def default_kwargs(self):
kwargs["seed"] = None
return kwargs

@property
def new_dynesty_api(self):
try:
import dynesty.internal_samplers # noqa

Comment thread
mj-will marked this conversation as resolved.
return True
except ImportError:
return False

def __init__(
self,
likelihood,
Expand Down Expand Up @@ -261,7 +270,20 @@ def sampler_function_kwargs(self):

@property
def sampler_init_kwargs(self):
return {key: self.kwargs[key] for key in self._dynesty_init_kwargs}
kwargs = {key: self.kwargs[key] for key in self._dynesty_init_kwargs}
if self.new_dynesty_api:
from . import dynesty3_utils as dynesty_utils

if kwargs["sample"] == "act-walk":
kwargs["sample"] = dynesty_utils.ACTTrackingEnsembleWalk(**kwargs)
kwargs["bound"] = "none"
elif kwargs["sample"] == "acceptance-walk":
kwargs["sample"] = dynesty_utils.EnsembleWalkSampler(**kwargs)
kwargs["bound"] = "none"
elif kwargs["sample"] == "rwalk":
kwargs["sample"] = dynesty_utils.AcceptanceTrackingRWalk(**kwargs)
kwargs["bound"] = "none"
return kwargs

def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
Expand Down Expand Up @@ -429,7 +451,7 @@ def nlive(self):

@property
def sampler_init(self):
from dynesty import NestedSampler
from dynesty.dynesty import NestedSampler

return NestedSampler

Expand All @@ -450,6 +472,9 @@ def _set_sampling_method(self):
Additionally, some combinations of bound/sample/proposals are not
compatible and so we either warn the user or raise an error.
"""
if self.new_dynesty_api:
return

import dynesty

_set_sampling_kwargs((self.nact, self.maxmcmc, self.proposals, self.naccept))
Expand Down Expand Up @@ -595,6 +620,10 @@ def _setup_pool(self):
more times than we have processes.
"""
super(Dynesty, self)._setup_pool()

if self.new_dynesty_api:
return

if self.pool is not None:
args = (
[(self.nact, self.maxmcmc, self.proposals, self.naccept)]
Expand Down Expand Up @@ -793,8 +822,17 @@ def read_saved_state(self, continuing=False):
if continuing:
self._remove_live()
self.sampler.nqueue = -1
self.start_time = self.sampler.kwargs.pop("start_time")
self.sampling_time = self.sampler.kwargs.pop("sampling_time")
if hasattr(self.sampler, "_bilby_metadata"):
extras = self.sampler._bilby_metadata
elif hasattr(self.sampler, "kwargs"):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused here. Why can the sampler have either metadata or kwargs?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the new API the dynesty sampler classes don't have a kwargs attribute that we can piggyback on.

extras = self.sampler.kwargs
else:
raise AttributeError(
"Loaded sampler doesn't contain timing info, "
"the checkpoint is probably corrupted."
)
self.start_time = extras.pop("start_time")
self.sampling_time = extras.pop("sampling_time")
self.sampler.queue_size = self.kwargs["queue_size"]
self.sampler.pool = self.pool
if self.pool is not None:
Expand Down Expand Up @@ -835,8 +873,10 @@ def write_current_state(self):
check_directory_exists_and_if_not_mkdir(self.outdir)
if hasattr(self, "start_time"):
self._update_sampling_time()
self.sampler.kwargs["sampling_time"] = self.sampling_time
self.sampler.kwargs["start_time"] = self.start_time
self.sampler._bilby_metadata = dict(
sampling_time=self.sampling_time,
start_time=self.start_time,
)
self.sampler.versions = dict(bilby=bilby_version, dynesty=dynesty_version)
self.sampler.pool = None
self.sampler.M = map
Expand Down
Loading
Loading