Merge pull request #31 from ColCarroll:main
PiperOrigin-RevId: 611093393
The bayeux Authors committed Feb 28, 2024
2 parents e12fe89 + 01dc6c1 commit 131b685
Showing 8 changed files with 477 additions and 210 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -23,6 +23,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

## [0.1.9] - 2024-02-27

### Add programmatic access to algorithms

## [0.1.8] - 2024-02-14

### Add HMC and NUTS from TFP
210 changes: 8 additions & 202 deletions README.md
@@ -14,7 +14,7 @@ pip install bayeux-ml
```
## Quickstart

We define a model by providing a log density in JAX. This could be defined using a probabilistic programming language (PPL) like [numpyro](/examples/numpyro_and_bayeux), [PyMC](/examples/pymc_and_bayeux), [TFP](/examples/tfp_and_bayeux), distrax, oryx, coix, or directly in JAX.

```python
import bayeux as bx
@@ -24,215 +24,21 @@ normal_density = bx.Model(
log_density=lambda x: -x*x,
test_point=1.)

seed = jax.random.key(0)
```

## Simple
Every inference algorithm in `bayeux` will (try to) run with just a seed as an argument:

```python
opt_results = normal_density.optimize.optax_adam(seed=seed)
# OR!
idata = normal_density.mcmc.numpyro_nuts(seed=seed)
# OR!
surrogate_posterior, loss = normal_density.vi.tfp_factored_surrogate_posterior(seed=seed)
```

An optional third argument to `bx.Model` (only rarely needed) is `transform_fn`, which maps a real number to the support of the distribution. The [oryx](https://github.com/jax-ml/oryx) library is used to automatically compute the inverse and Jacobian determinant for the change of variables, but the user can supply these if known.

```python
half_normal_density = bx.Model(
lambda x: -x*x,
test_point=1.,
transform_fn=jax.nn.softplus)
```
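As a pure-JAX sanity check (a sketch, independent of `bayeux` itself): `softplus` maps every real input into the positive reals, the support of the half-normal above, which is exactly what a `transform_fn` must do.

```python
import jax
import jax.numpy as jnp

# softplus(x) = log(1 + exp(x)) maps all of R onto (0, inf), the
# support of the half-normal, so it is a valid transform_fn here.
xs = jnp.linspace(-10.0, 10.0, 101)
ys = jax.nn.softplus(xs)
print(bool(jnp.all(ys > 0)))  # True
```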

## Self descriptive

Since `bayeux` is built on top of other fantastic libraries, it tries not to get in their way. Each algorithm has a `.get_kwargs()` method that tells you what functions will be called and with what arguments:

```python
normal_density.optimize.jaxopt_bfgs.get_kwargs()

{jaxopt._src.bfgs.BFGS: {'value_and_grad': False,
'has_aux': False,
'maxiter': 500,
'tol': 0.001,
'stepsize': 0.0,
'linesearch': 'zoom',
'linesearch_init': 'increase',
'condition': None,
'maxls': 30,
'decrease_factor': None,
'increase_factor': 1.5,
'max_stepsize': 1.0,
'min_stepsize': 1e-06,
'implicit_diff': True,
'implicit_diff_solve': None,
'jit': True,
'unroll': 'auto',
'verbose': False},
'extra_parameters': {'chain_method': 'vectorized',
'num_particles': 8,
'num_iters': 1000,
'apply_transform': True}}
```

If you pass arguments into `.get_kwargs()`, it will also tell you what will be passed on to the actual algorithms.

```
normal_density.mcmc.blackjax_nuts.get_kwargs(
num_chains=5,
target_acceptance_rate=0.99)
{<blackjax.adaptation.window_adaptation.window_adaptation: {'is_mass_matrix_diagonal': True,
'initial_step_size': 1.0,
'target_acceptance_rate': 0.99,
'progress_bar': False,
'algorithm': blackjax.mcmc.nuts.nuts},
blackjax.mcmc.nuts.nuts: {'max_num_doublings': 10,
'divergence_threshold': 1000,
'integrator': blackjax.mcmc.integrators.velocity_verlet,
'step_size': 0.01},
'extra_parameters': {'chain_method': 'vectorized',
'num_chains': 5,
'num_draws': 500,
'num_adapt_draws': 500,
'return_pytree': False}}
```
## Read more

A full list of available algorithms and how to call them can be seen with

```python
print(normal_density)

mcmc
.blackjax_hmc
.blackjax_nuts
.blackjax_hmc_pathfinder
.blackjax_nuts_pathfinder
.numpyro_hmc
.numpyro_nuts
optimize
.jaxopt_bfgs
.jaxopt_gradient_descent
.jaxopt_lbfgs
.jaxopt_nonlinear_cg
.optax_adabelief
.optax_adafactor
.optax_adagrad
.optax_adam
.optax_adamw
.optax_adamax
.optax_amsgrad
.optax_fromage
.optax_lamb
.optax_lion
.optax_noisy_sgd
.optax_novograd
.optax_radam
.optax_rmsprop
.optax_sgd
.optax_sm3
.optax_yogi
vi
.tfp_factored_surrogate_posterior

```

## Helpful

Algorithms come with a built-in `debug` mode that attempts to fail fast and in a manner that helps pinpoint problems. The signature for `debug` accepts `verbosity` and `catch_exceptions` arguments, as well as a `kwargs` dictionary that the user plans to pass to the algorithm itself.

```python
normal_density.mcmc.numpyro_nuts.debug(seed=seed)

Checking test_point shape ✓
Computing test point log density ✓
Loading keyword arguments...
Checking it is possible to compute an initial state ✓
Checking initial state has no NaN ✓
Computing initial state log density ✓
Transforming model to R^n ✓
Computing transformed state log density shape ✓
Comparing transformed log density to untransformed ✓
Computing gradients of transformed log density ✓
True
```

Here is an example of a bad model with a higher verbosity:
```python
import jax.numpy as jnp

bad_model = bx.Model(
log_density=jnp.sqrt,
test_point=-1.)

bad_model.mcmc.blackjax_nuts.debug(jax.random.PRNGKey(0),
verbosity=3, kwargs={"num_chains": 17})

Checking test_point shape ✓
Test point has shape
()
✓✓✓✓✓✓✓✓✓✓

Computing test point log density ×
Test point has log density
Array(nan, dtype=float32, weak_type=True)
××××××××××

Loading keyword arguments...
Keyword arguments are
{<function window_adaptation at 0x77feef9308b0>: {'algorithm': <class 'blackjax.mcmc.nuts.nuts'>,
'initial_step_size': 1.0,
'is_mass_matrix_diagonal': True,
'progress_bar': False,
'target_acceptance_rate': 0.8},
'extra_parameters': {'chain_method': 'vectorized',
'num_adapt_draws': 500,
'num_chains': 17,
'num_draws': 500,
'return_pytree': False},
<class 'blackjax.mcmc.nuts.nuts'>: {'divergence_threshold': 1000,
'integrator': <function velocity_verlet at 0x77feefbf4b80>,
'max_num_doublings': 10,
'step_size': 0.01}}
✓✓✓✓✓✓✓✓✓✓

Checking it is possible to compute an initial state ✓
Initial state has shape
(17,)
✓✓✓✓✓✓✓✓✓✓

Checking initial state has no NaN ✓
No nans detected!
✓✓✓✓✓✓✓✓✓✓

Computing initial state log density ×
Initial state has log density
Array([1.2212421 , nan, nan, 1.4113309 , nan,
nan, nan, nan, nan, nan,
0.5912253 , nan, nan, nan, 0.65457666,
nan, nan], dtype=float32)
××××××××××

Transforming model to R^n ✓
Transformed state has shape
(17,)
✓✓✓✓✓✓✓✓✓✓

Computing transformed state log density shape ✓
Transformed state log density has shape
(17,)
✓✓✓✓✓✓✓✓✓✓

Computing gradients of transformed log density ×
The gradient contains NaNs! Initial gradients has shape
(17,)
××××××××××

False
```
* [Defining models](/inference)
* [Inspecting models](/inspecting)
* [Testing and debugging](/debug_mode)
* Also see `bayeux` integration with [numpyro](/examples/numpyro_and_bayeux), [PyMC](/examples/pymc_and_bayeux), and [TFP](/examples/tfp_and_bayeux)!


*This is not an officially supported Google product.*
2 changes: 1 addition & 1 deletion bayeux/__init__.py
@@ -16,7 +16,7 @@

# A new PyPI release will be pushed every time `__version__` is increased
# When changing this, also update the CHANGELOG.md
__version__ = '0.1.8'
__version__ = '0.1.9'

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
20 changes: 13 additions & 7 deletions bayeux/_src/bayeux.py
@@ -36,15 +36,15 @@
class _Namespace:

def __init__(self):
self._fns = []
self.methods = []

def __repr__(self):
return "\n".join(self._fns)
return "\n".join(self.methods)

def __setclass__(self, clas, parent):
kwargs = {k: getattr(parent, k) for k in _REQUIRED_KWARGS}
setattr(self, clas.name, clas(**kwargs))
self._fns.append(clas.name)
self.methods.append(clas.name)


def is_tfp_bijector(bij):
@@ -100,12 +100,18 @@ def __post_init__(self):

def __repr__(self):
methods = []
for name in self._namespaces:
methods.append(name)
k = getattr(self, name)
methods.append("\t." + "\n\t.".join(str(k).split()))
for key, values in self.methods.items():
methods.append(key)
methods.append("\t." + "\n\t.".join(values))
return "\n".join(methods)

@property
def methods(self):
methods = {}
for name in self._namespaces:
methods[name] = getattr(self, name).methods
return methods

@classmethod
def from_tfp(cls, pinned_joint_distribution, initial_state=None):
log_density = pinned_joint_distribution.log_prob
