-
Notifications
You must be signed in to change notification settings - Fork 129
ENH: add support for new dynesty api #950
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1afa76e
c89ad56
5998eb7
88489fe
e245dae
e61c7f2
f75da57
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| # This workflow will install Python dependencies, run tests and lint with a variety of Python versions | ||
| # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python | ||
|
|
||
| name: Test against latest dynesty | ||
|
|
||
| on: | ||
| push: | ||
| branches: [ "main" ] | ||
| pull_request: | ||
| branches: [ "main" ] | ||
| merge_group: | ||
| release: | ||
| types: | ||
| - published | ||
|
|
||
| concurrency: | ||
| group: ${{ github.workflow }}-${{ github.ref }} | ||
| cancel-in-progress: true | ||
|
|
||
| env: | ||
| CONDA_PATH: /opt/conda/ | ||
|
|
||
| jobs: | ||
| build: | ||
|
|
||
| name: ${{ matrix.python.name }} unit tests | ||
| runs-on: ubuntu-latest | ||
| container: ghcr.io/bilby-dev/bilby-python${{ matrix.python.short-version }}:latest | ||
| strategy: | ||
| fail-fast: false | ||
| matrix: | ||
| python: | ||
| - name: Python 3.12 | ||
| version: 3.12 | ||
| short-version: 312 | ||
|
|
||
| steps: | ||
| - uses: actions/checkout@v4 | ||
| with: | ||
| fetch-depth: 0 | ||
| fetch-tags: true | ||
| - name: Install package | ||
| run: | | ||
| # activate env so that conda list shows the correct environment | ||
| source $CONDA_PATH/bin/activate python${{ matrix.python.short-version }} | ||
| python -m pip install . | ||
| python -m pip install git+https://github.com/joshspeagle/dynesty@master | ||
| conda list --show-channel-urls | ||
| shell: bash | ||
| - name: Run unit tests | ||
| run: | | ||
| python -m pytest --durations 10 -k dynesty | ||
| - name: Run sampler tests | ||
| run: | | ||
| python -m pytest test/integration/sampler_run_test.py --durations 10 -v -k dynesty |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -189,6 +189,15 @@ def default_kwargs(self): | |
| kwargs["seed"] = None | ||
| return kwargs | ||
|
|
||
| @property | ||
| def new_dynesty_api(self): | ||
| try: | ||
| import dynesty.internal_samplers # noqa | ||
|
|
||
| return True | ||
| except ImportError: | ||
| return False | ||
|
|
||
| def __init__( | ||
| self, | ||
| likelihood, | ||
|
|
@@ -261,7 +270,20 @@ def sampler_function_kwargs(self): | |
|
|
||
| @property | ||
| def sampler_init_kwargs(self): | ||
| return {key: self.kwargs[key] for key in self._dynesty_init_kwargs} | ||
| kwargs = {key: self.kwargs[key] for key in self._dynesty_init_kwargs} | ||
| if self.new_dynesty_api: | ||
| from . import dynesty3_utils as dynesty_utils | ||
|
|
||
| if kwargs["sample"] == "act-walk": | ||
| kwargs["sample"] = dynesty_utils.ACTTrackingEnsembleWalk(**kwargs) | ||
| kwargs["bound"] = "none" | ||
| elif kwargs["sample"] == "acceptance-walk": | ||
| kwargs["sample"] = dynesty_utils.EnsembleWalkSampler(**kwargs) | ||
| kwargs["bound"] = "none" | ||
| elif kwargs["sample"] == "rwalk": | ||
| kwargs["sample"] = dynesty_utils.AcceptanceTrackingRWalk(**kwargs) | ||
| kwargs["bound"] = "none" | ||
| return kwargs | ||
|
|
||
| def _translate_kwargs(self, kwargs): | ||
| kwargs = super()._translate_kwargs(kwargs) | ||
|
|
@@ -429,7 +451,7 @@ def nlive(self): | |
|
|
||
| @property | ||
| def sampler_init(self): | ||
| from dynesty import NestedSampler | ||
| from dynesty.dynesty import NestedSampler | ||
|
|
||
| return NestedSampler | ||
|
|
||
|
|
@@ -450,6 +472,9 @@ def _set_sampling_method(self): | |
| Additionally, some combinations of bound/sample/proposals are not | ||
| compatible and so we either warn the user or raise an error. | ||
| """ | ||
| if self.new_dynesty_api: | ||
| return | ||
|
|
||
| import dynesty | ||
|
|
||
| _set_sampling_kwargs((self.nact, self.maxmcmc, self.proposals, self.naccept)) | ||
|
|
@@ -595,6 +620,10 @@ def _setup_pool(self): | |
| more times than we have processes. | ||
| """ | ||
| super(Dynesty, self)._setup_pool() | ||
|
|
||
| if self.new_dynesty_api: | ||
| return | ||
|
|
||
| if self.pool is not None: | ||
| args = ( | ||
| [(self.nact, self.maxmcmc, self.proposals, self.naccept)] | ||
|
|
@@ -793,8 +822,17 @@ def read_saved_state(self, continuing=False): | |
| if continuing: | ||
| self._remove_live() | ||
| self.sampler.nqueue = -1 | ||
| self.start_time = self.sampler.kwargs.pop("start_time") | ||
| self.sampling_time = self.sampler.kwargs.pop("sampling_time") | ||
| if hasattr(self.sampler, "_bilby_metadata"): | ||
| extras = self.sampler._bilby_metadata | ||
| elif hasattr(self.sampler, "kwargs"): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a bit confused here. Why can the sampler have either metadata or kwargs?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the new API the dynesty sampler classes don't have a |
||
| extras = self.sampler.kwargs | ||
| else: | ||
| raise AttributeError( | ||
| "Loaded sampler doesn't contain timing info, " | ||
| "the checkpoint is probably corrupted." | ||
| ) | ||
| self.start_time = extras.pop("start_time") | ||
| self.sampling_time = extras.pop("sampling_time") | ||
| self.sampler.queue_size = self.kwargs["queue_size"] | ||
| self.sampler.pool = self.pool | ||
| if self.pool is not None: | ||
|
|
@@ -835,8 +873,10 @@ def write_current_state(self): | |
| check_directory_exists_and_if_not_mkdir(self.outdir) | ||
| if hasattr(self, "start_time"): | ||
| self._update_sampling_time() | ||
| self.sampler.kwargs["sampling_time"] = self.sampling_time | ||
| self.sampler.kwargs["start_time"] = self.start_time | ||
| self.sampler._bilby_metadata = dict( | ||
| sampling_time=self.sampling_time, | ||
| start_time=self.start_time, | ||
| ) | ||
| self.sampler.versions = dict(bilby=bilby_version, dynesty=dynesty_version) | ||
| self.sampler.pool = None | ||
| self.sampler.M = map | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.