Skip to content

Commit

Permalink
Update to handle changed signature for `blackjax.run_inference_algori…
Browse files Browse the repository at this point in the history
…thm`.

PiperOrigin-RevId: 662494836
  • Loading branch information
ColCarroll authored and The bayeux Authors committed Aug 13, 2024
1 parent e46e7f4 commit ea84b6d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
-->

## [0.1.13] - 2024-08-13

* Prepare for more blackjax API changes.

## [0.1.13] - 2024-07-10

* Prepare for blackjax API change.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# A new PyPI release will be pushed everytime `__version__` is increased
# When changing this, also update the CHANGELOG.md
__version__ = '0.1.13'
__version__ = '0.1.14'

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
Expand Down
8 changes: 6 additions & 2 deletions bayeux/_src/mcmc/blackjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,17 @@ def _blackjax_inference(

algorithm_kwargs = kwargs[_convert_algorithm(algorithm)] | adapt_parameters
inference_algorithm = algorithm(**algorithm_kwargs)
_, states, infos = blackjax.util.run_inference_algorithm(
# This is protecting against a change in blackjax where the
# return from `run_inference_algorithm` changes from
# `_, states, infos` to `_, (states, infos)`. This one weird
# trick handles both cases.
_, *states_and_infos = blackjax.util.run_inference_algorithm(
rng_key=seed,
inference_algorithm=inference_algorithm,
num_steps=num_draws,
progress_bar=False,
**{_INFERENCE_KWARG: adapt_state})
return states, infos
return states_and_infos


def _blackjax_inference_loop(
Expand Down

0 comments on commit ea84b6d

Please sign in to comment.