diff --git a/CHANGELOG.md b/CHANGELOG.md index f843850..84951ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): --> +## [0.1.13] - 2024-08-13 + +* Prepare for more blackjax API changes. + ## [0.1.13] - 2024-07-10 * Prepare for blackjax API change. diff --git a/bayeux/__init__.py b/bayeux/__init__.py index ae1f30e..7fadd7a 100644 --- a/bayeux/__init__.py +++ b/bayeux/__init__.py @@ -16,7 +16,7 @@ # A new PyPI release will be pushed everytime `__version__` is increased # When changing this, also update the CHANGELOG.md -__version__ = '0.1.13' +__version__ = '0.1.14' # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/google/jax/issues/7570 diff --git a/bayeux/_src/mcmc/blackjax.py b/bayeux/_src/mcmc/blackjax.py index 757cc74..9faa219 100644 --- a/bayeux/_src/mcmc/blackjax.py +++ b/bayeux/_src/mcmc/blackjax.py @@ -198,13 +198,17 @@ def _blackjax_inference( algorithm_kwargs = kwargs[_convert_algorithm(algorithm)] | adapt_parameters inference_algorithm = algorithm(**algorithm_kwargs) - _, states, infos = blackjax.util.run_inference_algorithm( + # This is protecting against a change in blackjax where the + # return from `run_inference_algorithm` changes from + # `_, states, infos` to `_, (states, infos)`. This one weird + # trick handles both cases. + _, *states_and_infos = blackjax.util.run_inference_algorithm( rng_key=seed, inference_algorithm=inference_algorithm, num_steps=num_draws, progress_bar=False, **{_INFERENCE_KWARG: adapt_state}) - return states, infos + return states_and_infos def _blackjax_inference_loop(