Skip to content

Commit

Permalink
Merge pull request #41 from ColCarroll:update-versions
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622846616
  • Loading branch information
The bayeux Authors committed Apr 8, 2024
2 parents b51d440 + 7b3df3f commit 2bed83c
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 22 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

## [0.1.11] - 2024-04-08

### Update for new version of flowMC
### Unpin scipy
### Add pymc as dependency

## [0.1.10] - 2024-04-03

### Keep Python 3.9 compatibility.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# A new PyPI release will be pushed everytime `__version__` is increased
# When changing this, also update the CHANGELOG.md
__version__ = '0.1.10'
__version__ = '0.1.11'

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
Expand Down
22 changes: 7 additions & 15 deletions bayeux/_src/mcmc/flowmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
"""flowMC specific code."""
import arviz as az
from bayeux._src import shared
from flowMC import Sampler
from flowMC.nfmodel import realNVP
from flowMC.nfmodel import rqSpline
from flowMC.sampler import HMC
from flowMC.sampler import MALA
from flowMC.sampler import Sampler
from flowMC.proposal import HMC
from flowMC.proposal import MALA
import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -98,11 +98,8 @@ def get_local_sampler_kwargs(local_sampler, log_density, n_features, kwargs):
sampler_kwargs.update(
{k: defaults[k] for k in sampler_required if k in defaults})
sampler_required = sampler_required - sampler_kwargs.keys()
if "params" in sampler_required:
sampler_kwargs["params"] = defaults
else:
sampler_kwargs["params"] = sampler_kwargs["params"] | defaults

sampler_kwargs.update(
{k: defaults[k] for k in sampler_kwargs if k in defaults})
sampler_required = sampler_required - sampler_kwargs.keys()

if sampler_required:
Expand Down Expand Up @@ -146,7 +143,7 @@ def get_sampler_kwargs(sampler, n_features, kwargs):
sampler_kwargs.update(
{k: defaults[k] for k in sampler_required if k in defaults})
sampler_required = (sampler_required -
{"nf_model", "local_sampler", "rng_key_set", "kwargs"})
{"nf_model", "local_sampler", "rng_key", "kwargs"})
sampler_required = sampler_required - sampler_kwargs.keys()

if sampler_required:
Expand Down Expand Up @@ -208,16 +205,11 @@ def __call__(self, seed, **kwargs):
nf_model = _NF_MODELS[self.nf_model]
local_sampler = _LOCAL_SAMPLERS[self.local_sampler]

rng_key_init, rng_key_mcmc, rng_key_nf = jax.random.split(seed, 3)
rng_keys_mcmc = jax.random.split(rng_key_mcmc, num_chains)
rng_keys_nf, init_rng_keys_nf = jax.random.split(rng_key_nf, 2)

model = nf_model(key=nf_key, **kwargs[nf_model])
local_sampler = local_sampler(**kwargs[local_sampler])
sampler = Sampler.Sampler
nf_sampler = sampler(
rng_key_set=(
rng_key_init, rng_keys_mcmc, rng_keys_nf, init_rng_keys_nf),
rng_key=seed,
local_sampler=local_sampler,
nf_model=model,
**kwargs[sampler])
Expand Down
6 changes: 3 additions & 3 deletions bayeux/tests/mcmc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ def test_return_pytree_flowmc():
def test_samplers(method):
# flowMC samplers are broken for 0 or 1 dimensions, so just test
# everything on 2 dimensions for now.
model = bx.Model(log_density=lambda pt: jnp.sum(-pt["x"]**2),
test_point={"x": jnp.ones((1, 2))})
model = bx.Model(log_density=lambda pt: -pt["x"]**2,
test_point={"x": 1.})
sampler = getattr(model.mcmc, method)
seed = jax.random.PRNGKey(0)
assert sampler.debug(seed=seed, verbosity=0)
idata = sampler(seed=seed)
if method == "blackjax_hmc":
if method.endswith("hmc"):
assert max_rhat(idata) < 1.2
else:
assert max_rhat(idata) < 1.1
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ dependencies = [
"optax",
"optimistix",
"blackjax",
"flowmc",
"flowmc>=0.3.0",
"numpyro",
"jaxopt",
"scipy<1.13", # https://github.com/arviz-devs/arviz/issues/2336
"pymc",
]

# `version` is automatically set by flit to use `bayeux.__version__`
Expand All @@ -44,7 +44,6 @@ dev = [
"pytest-xdist",
"pylint>=2.6.0",
"pyink",
"pymc",
]


Expand Down

0 comments on commit 2bed83c

Please sign in to comment.