diff --git a/CHANGELOG.md b/CHANGELOG.md index 1098ef0..98b10ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,12 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): ## [Unreleased] +## [0.1.11] - 2024-04-08 + +### Update for new version of flowMC +### Unpin scipy +### Add pymc as dependency + ## [0.1.10] - 2024-04-03 ### Keep Python 3.9 compatibility. diff --git a/bayeux/__init__.py b/bayeux/__init__.py index b2af432..231bcfa 100644 --- a/bayeux/__init__.py +++ b/bayeux/__init__.py @@ -16,7 +16,7 @@ # A new PyPI release will be pushed everytime `__version__` is increased # When changing this, also update the CHANGELOG.md -__version__ = '0.1.10' +__version__ = '0.1.11' # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/google/jax/issues/7570 diff --git a/bayeux/_src/mcmc/flowmc.py b/bayeux/_src/mcmc/flowmc.py index b5376d5..f3bcd8f 100644 --- a/bayeux/_src/mcmc/flowmc.py +++ b/bayeux/_src/mcmc/flowmc.py @@ -31,9 +31,9 @@ from bayeux._src import shared from flowMC.nfmodel import realNVP from flowMC.nfmodel import rqSpline -from flowMC.sampler import HMC -from flowMC.sampler import MALA -from flowMC.sampler import Sampler +from flowMC.proposal import HMC +from flowMC.proposal import MALA +from flowMC import Sampler import jax import jax.numpy as jnp @@ -98,11 +98,8 @@ def get_local_sampler_kwargs(local_sampler, log_density, n_features, kwargs): sampler_kwargs.update( {k: defaults[k] for k in sampler_required if k in defaults}) sampler_required = sampler_required - sampler_kwargs.keys() - if "params" in sampler_required: - sampler_kwargs["params"] = defaults - else: - sampler_kwargs["params"] = sampler_kwargs["params"] | defaults - + sampler_kwargs.update( + {k: defaults[k] for k in sampler_kwargs if k in defaults}) sampler_required = sampler_required - sampler_kwargs.keys() if sampler_required: @@ -146,7 +143,7 @@ def get_sampler_kwargs(sampler, n_features, kwargs): sampler_kwargs.update( {k: defaults[k] for k in sampler_required if k in defaults}) sampler_required = (sampler_required - - {"nf_model", "local_sampler", "rng_key_set", "kwargs"}) + {"nf_model", "local_sampler", "rng_key", "kwargs"}) sampler_required = sampler_required - sampler_kwargs.keys() if sampler_required: @@ -208,16 +205,11 @@ def __call__(self, seed, **kwargs): nf_model = _NF_MODELS[self.nf_model] local_sampler = _LOCAL_SAMPLERS[self.local_sampler] - rng_key_init, rng_key_mcmc, rng_key_nf = jax.random.split(seed, 3) - rng_keys_mcmc = jax.random.split(rng_key_mcmc, num_chains) - rng_keys_nf, init_rng_keys_nf = jax.random.split(rng_key_nf, 2) - model = nf_model(key=nf_key, **kwargs[nf_model]) local_sampler = local_sampler(**kwargs[local_sampler]) sampler = Sampler.Sampler nf_sampler = sampler( - rng_key_set=( - rng_key_init, rng_keys_mcmc, rng_keys_nf, init_rng_keys_nf), + rng_key=seed, local_sampler=local_sampler, nf_model=model, **kwargs[sampler]) diff --git a/bayeux/tests/mcmc_test.py b/bayeux/tests/mcmc_test.py index 7fbb4fb..e1fff71 100644 --- a/bayeux/tests/mcmc_test.py +++ b/bayeux/tests/mcmc_test.py @@ -116,13 +116,13 @@ def test_return_pytree_flowmc(): def test_samplers(method): # flowMC samplers are broken for 0 or 1 dimensions, so just test # everything on 2 dimensions for now. - model = bx.Model(log_density=lambda pt: jnp.sum(-pt["x"]**2), - test_point={"x": jnp.ones((1, 2))}) + model = bx.Model(log_density=lambda pt: -pt["x"]**2, + test_point={"x": 1.}) sampler = getattr(model.mcmc, method) seed = jax.random.PRNGKey(0) assert sampler.debug(seed=seed, verbosity=0) idata = sampler(seed=seed) - if method == "blackjax_hmc": + if method.endswith("hmc"): assert max_rhat(idata) < 1.2 else: assert max_rhat(idata) < 1.1 diff --git a/pyproject.toml b/pyproject.toml index 69757ea..e4739cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,10 +21,10 @@ dependencies = [ "optax", "optimistix", "blackjax", - "flowmc", + "flowmc>=0.3.0", "numpyro", "jaxopt", - "scipy<1.13", # https://github.com/arviz-devs/arviz/issues/2336 + "pymc", ] # `version` is automatically set by flit to use `bayeux.__version__` @@ -44,7 +44,6 @@ dev = [ "pytest-xdist", "pylint>=2.6.0", "pyink", - "pymc", ]