Skip to content

Commit

Permalink
Add flowMC samplers and release a new version.
Browse files Browse the repository at this point in the history
These currently only work for more than 1 dimension, and may require some further tuning.

PiperOrigin-RevId: 603427758
  • Loading branch information
ColCarroll authored and The bayeux Authors committed Feb 1, 2024
1 parent 721de27 commit a77a136
Show file tree
Hide file tree
Showing 28 changed files with 321 additions and 31 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

## [0.1.6] - 2024-02-01

### Add samplers from flowMC

## [0.1.5] - 2024-01-12

### Bugfix for PyMC models
Expand All @@ -47,6 +51,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
### Initial release

[Unreleased]: https://github.com/jax-ml/bayeux/compare/v0.1.5...HEAD
[0.1.6]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.6
[0.1.5]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.5
[0.1.4]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.4
[0.1.3]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.3
Expand Down
4 changes: 2 additions & 2 deletions bayeux/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,7 @@

# A new PyPI release will be pushed everytime `__version__` is increased
# When changing this, also update the CHANGELOG.md
__version__ = '0.1.5'
__version__ = '0.1.6'

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/bayeux.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/debug.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/initialization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/mcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
6 changes: 3 additions & 3 deletions bayeux/_src/mcmc/blackjax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -312,8 +312,8 @@ def get_algorithm_kwargs(algorithm, log_density, kwargs):
algorithm_kwargs, algorithm_required = shared.get_default_signature(algorithm)
kwargs_with_defaults = {
"logdensity_fn": log_density,
"step_size": 0.01,
"num_integration_steps": 8,
"step_size": 0.5,
"num_integration_steps": 16,
} | kwargs
algorithm_kwargs.update(
{
Expand Down
267 changes: 267 additions & 0 deletions bayeux/_src/mcmc/flowmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""flowMC specific code."""
import arviz as az
from bayeux._src import shared
from flowMC.nfmodel import realNVP
from flowMC.nfmodel import rqSpline
from flowMC.sampler import HMC
from flowMC.sampler import MALA
from flowMC.sampler import Sampler
import jax
import jax.numpy as jnp


_NF_MODELS = {
"real_nvp": realNVP.RealNVP,
"masked_coupling_rq_spline": rqSpline.MaskedCouplingRQSpline,
}

_LOCAL_SAMPLERS = {"mala": MALA.MALA, "hmc": HMC.HMC}


def get_nf_model_kwargs(nf_model, n_features, kwargs):
"""Sets defaults and merges user-provided adaptation keywords."""
nf_model_kwargs, nf_model_required = shared.get_default_signature(
nf_model)
nf_model_kwargs.update(
{k: kwargs[k] for k in nf_model_kwargs if k in kwargs})
nf_model_kwargs.update(
{k: kwargs[k] for k in nf_model_required if k in kwargs})
nf_model_kwargs.setdefault("n_features", n_features)
nf_model_required.remove("key")
nf_model_required.remove("kwargs")
nf_model_required = nf_model_required - nf_model_kwargs.keys()

defaults = {
# RealNVP kwargs
"n_hidden": 100,
"n_layer": 10,
# MaskedCouplingRQSpline kwargs
"n_layers": 4,
"num_bins": 8,
"hidden_size": [64, 64],
"spline_range": (-10.0, 10.0),
}
for key, value in defaults.items():
if key in nf_model_required:
nf_model_kwargs[key] = value

nf_model_required = nf_model_required - nf_model_kwargs.keys()

if nf_model_required:
raise ValueError(
"Unexpected required arguments: "
f"{','.join(nf_model_required)}. Probably file a bug, but "
"you can try to manually supply them as keywords."
)
return nf_model_kwargs


def get_local_sampler_kwargs(local_sampler, log_density, n_features, kwargs):
"""Sets defaults and merges user-provided adaptation keywords."""

kwargs["logpdf"] = log_density
sampler_kwargs, sampler_required = shared.get_default_signature(
local_sampler)
sampler_kwargs.setdefault("jit", True)
sampler_kwargs.update(
{k: kwargs[k] for k in sampler_required if k in kwargs})
sampler_required = sampler_required - sampler_kwargs.keys()

defaults = {
# HMC kwargs
"condition_matrix": jnp.eye(n_features),
"n_leapfrog": 10,
# Both
"step_size": 0.1,
}
if "params" in sampler_required:
sampler_kwargs["params"] = defaults
else:
sampler_kwargs["params"] = sampler_kwargs["params"] | defaults

sampler_required = sampler_required - sampler_kwargs.keys()

if sampler_required:
raise ValueError(
"Unexpected required arguments: "
f"{','.join(sampler_required)}. Probably file a bug, but "
"you can try to manually supply them as keywords."
)
return sampler_kwargs


def get_sampler_kwargs(sampler, n_features, kwargs):
"""Sets defaults and merges user-provided adaptation keywords."""
sampler_kwargs, sampler_required = shared.get_default_signature(sampler)
sampler_kwargs.update(
{k: kwargs[k] for k in sampler_required if k in kwargs})
sampler_kwargs.setdefault("data", {})
sampler_kwargs.setdefault("n_dim", n_features)
sampler_required = (sampler_required -
{"nf_model", "local_sampler", "rng_key_set", "kwargs"})
sampler_required = sampler_required - sampler_kwargs.keys()

defaults = {
"n_loop_training": 5,
"n_loop_production": 5,
"n_local_steps": 50,
"n_global_steps": 50,
"n_chains": 20,
"n_epochs": 30,
"learning_rate": 0.01,
"max_samples": 10_000,
"momentum": 0.9,
"batch_size": 10_000,
"use_global": True,
"global_sampler": None,
"logging": True,
"keep_quantile": 0.,
"local_autotune": None,
"train_thinning": 1,
"output_thinning": 1,
"n_sample_max": 10_000,
"precompile": False,
"verbose": False}
for key, value in defaults.items():
if key not in sampler_kwargs:
sampler_kwargs[key] = value

sampler_required = sampler_required - sampler_kwargs.keys()

if sampler_required:
raise ValueError(
"Unexpected required arguments: "
f"{','.join(sampler_required)}. Probably file a bug, but "
"you can try to manually supply them as keywords."
)
return sampler_kwargs


class _FlowMCSampler(shared.Base):
"""Base class for flowmc samplers."""
name: str = ""
nf_model: str = ""
local_sampler: str = ""

def _get_aux(self):
flat, unflatten = jax.flatten_util.ravel_pytree(self.test_point)

@jax.vmap
def flatten(pytree):
return jax.flatten_util.ravel_pytree(pytree)[0]

constrained_log_density = self.constrained_log_density()
def log_density(x, _):
return constrained_log_density(unflatten(x)).squeeze()

return log_density, flatten, unflatten, flat.shape[0]

def get_kwargs(self, **kwargs):
nf_model = _NF_MODELS[self.nf_model]
local_sampler = _LOCAL_SAMPLERS[self.local_sampler]
log_density, flatten, unflatten, n_features = self._get_aux()

nf_model_kwargs = get_nf_model_kwargs(nf_model, n_features, kwargs)
local_sampler_kwargs = get_local_sampler_kwargs(
local_sampler, log_density, n_features, kwargs)
sampler = Sampler.Sampler
sampler_kwargs = get_sampler_kwargs(sampler, n_features, kwargs)
extra_parameters = {"flatten": flatten,
"unflatten": unflatten,
"num_chains": sampler_kwargs["n_chains"],
"return_pytree": kwargs.get("return_pytree", False)}

return {nf_model: nf_model_kwargs,
local_sampler: local_sampler_kwargs,
sampler: sampler_kwargs,
"extra_parameters": extra_parameters}

def __call__(self, seed, **kwargs):
kwargs = self.get_kwargs(**kwargs)
extra_parameters = kwargs["extra_parameters"]
num_chains = extra_parameters["num_chains"]
init_key, nf_key, seed = jax.random.split(seed, 3)
initial_state = self.get_initial_state(
init_key, num_chains=num_chains)
initial_state = extra_parameters["flatten"](initial_state)
nf_model = _NF_MODELS[self.nf_model]
local_sampler = _LOCAL_SAMPLERS[self.local_sampler]

rng_key_init, rng_key_mcmc, rng_key_nf = jax.random.split(seed, 3)
rng_keys_mcmc = jax.random.split(rng_key_mcmc, num_chains)
rng_keys_nf, init_rng_keys_nf = jax.random.split(rng_key_nf, 2)

model = nf_model(key=nf_key, **kwargs[nf_model])
local_sampler = local_sampler(**kwargs[local_sampler])
sampler = Sampler.Sampler
nf_sampler = sampler(
rng_key_set=(
rng_key_init, rng_keys_mcmc, rng_keys_nf, init_rng_keys_nf),
local_sampler=local_sampler,
nf_model=model,
**kwargs[sampler])
nf_sampler.sample(initial_state, {})
chains, *_ = nf_sampler.get_sampler_state().values()

unflatten = jax.vmap(jax.vmap(extra_parameters["unflatten"]))
pytree = self.transform_fn(unflatten(chains))
if extra_parameters["return_pytree"]:
return pytree
else:
if hasattr(pytree, "_asdict"):
pytree = pytree._asdict()
elif not isinstance(pytree, dict):
pytree = {"var0": pytree}
return az.from_dict(posterior=pytree)


class RealNVPMALA(_FlowMCSampler):
name = "flowmc_realnvp_mala"
nf_model = "real_nvp"
local_sampler = "mala"


class RealNVPHMC(_FlowMCSampler):
name = "flowmc_realnvp_hmc"
nf_model = "real_nvp"
local_sampler = "hmc"


class MaskedCouplingRQSplineMALA(_FlowMCSampler):
name = "flowmc_rqspline_mala"
nf_model = "masked_coupling_rq_spline"
local_sampler = "mala"


class MaskedCouplingRQSplineHMC(_FlowMCSampler):
name = "flowmc_rqspline_hmc"
nf_model = "masked_coupling_rq_spline"
local_sampler = "hmc"
2 changes: 1 addition & 1 deletion bayeux/_src/mcmc/numpyro.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/optimize/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/optimize/jaxopt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/optimize/optax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/optimize/optimistix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/optimize/shared.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/shared.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion bayeux/_src/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The bayeux Authors.
# Copyright 2024 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Loading

0 comments on commit a77a136

Please sign in to comment.