Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update flowMC, unpin scipy #41

Merged
merged 1 commit into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

## [0.1.11] - 2024-04-08

### Update for new version of flowMC
### Unpin scipy
### Add pymc as dependency

## [0.1.10] - 2024-04-03

### Keep Python 3.9 compatibility.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# A new PyPI release will be pushed everytime `__version__` is increased
# When changing this, also update the CHANGELOG.md
__version__ = '0.1.10'
__version__ = '0.1.11'

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
Expand Down
22 changes: 7 additions & 15 deletions bayeux/_src/mcmc/flowmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
from bayeux._src import shared
from flowMC.nfmodel import realNVP
from flowMC.nfmodel import rqSpline
from flowMC.sampler import HMC
from flowMC.sampler import MALA
from flowMC.sampler import Sampler
from flowMC.proposal import HMC
from flowMC.proposal import MALA
from flowMC import Sampler
import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -98,11 +98,8 @@ def get_local_sampler_kwargs(local_sampler, log_density, n_features, kwargs):
sampler_kwargs.update(
{k: defaults[k] for k in sampler_required if k in defaults})
sampler_required = sampler_required - sampler_kwargs.keys()
if "params" in sampler_required:
sampler_kwargs["params"] = defaults
else:
sampler_kwargs["params"] = sampler_kwargs["params"] | defaults

sampler_kwargs.update(
{k: defaults[k] for k in sampler_kwargs if k in defaults})
sampler_required = sampler_required - sampler_kwargs.keys()

if sampler_required:
Expand Down Expand Up @@ -146,7 +143,7 @@ def get_sampler_kwargs(sampler, n_features, kwargs):
sampler_kwargs.update(
{k: defaults[k] for k in sampler_required if k in defaults})
sampler_required = (sampler_required -
{"nf_model", "local_sampler", "rng_key_set", "kwargs"})
{"nf_model", "local_sampler", "rng_key", "kwargs"})
sampler_required = sampler_required - sampler_kwargs.keys()

if sampler_required:
Expand Down Expand Up @@ -208,16 +205,11 @@ def __call__(self, seed, **kwargs):
nf_model = _NF_MODELS[self.nf_model]
local_sampler = _LOCAL_SAMPLERS[self.local_sampler]

rng_key_init, rng_key_mcmc, rng_key_nf = jax.random.split(seed, 3)
rng_keys_mcmc = jax.random.split(rng_key_mcmc, num_chains)
rng_keys_nf, init_rng_keys_nf = jax.random.split(rng_key_nf, 2)

model = nf_model(key=nf_key, **kwargs[nf_model])
local_sampler = local_sampler(**kwargs[local_sampler])
sampler = Sampler.Sampler
nf_sampler = sampler(
rng_key_set=(
rng_key_init, rng_keys_mcmc, rng_keys_nf, init_rng_keys_nf),
rng_key=seed,
local_sampler=local_sampler,
nf_model=model,
**kwargs[sampler])
Expand Down
6 changes: 3 additions & 3 deletions bayeux/tests/mcmc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ def test_return_pytree_flowmc():
def test_samplers(method):
# flowMC samplers are broken for 0 or 1 dimensions, so just test
# everything on 2 dimensions for now.
model = bx.Model(log_density=lambda pt: jnp.sum(-pt["x"]**2),
test_point={"x": jnp.ones((1, 2))})
model = bx.Model(log_density=lambda pt: -pt["x"]**2,
test_point={"x": 1.})
sampler = getattr(model.mcmc, method)
seed = jax.random.PRNGKey(0)
assert sampler.debug(seed=seed, verbosity=0)
idata = sampler(seed=seed)
if method == "blackjax_hmc":
if method.endswith("hmc"):
assert max_rhat(idata) < 1.2
else:
assert max_rhat(idata) < 1.1
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ dependencies = [
"optax",
"optimistix",
"blackjax",
"flowmc",
"flowmc>=0.3.0",
"numpyro",
"jaxopt",
"scipy<1.13", # https://github.com/arviz-devs/arviz/issues/2336
"pymc",
]

# `version` is automatically set by flit to use `bayeux.__version__`
Expand All @@ -44,7 +44,6 @@ dev = [
"pytest-xdist",
"pylint>=2.6.0",
"pyink",
"pymc",
]


Expand Down