Sample contrastive #167

Merged
Merged 3 commits on Jul 24, 2024
10 changes: 5 additions & 5 deletions docs/examples/snpe.ipynb

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion flowjax/train/data_fit.py
@@ -97,21 +97,24 @@ def fit_to_data(
        # Train epoch
        batch_losses = []
        for batch in zip(*get_batches(train_data, batch_size), strict=True):
            key, subkey = jr.split(key)
            params, opt_state, loss_i = step(
                params,
                static,
                *batch,
                optimizer=optimizer,
                opt_state=opt_state,
                loss_fn=loss_fn,
                key=subkey,
            )
            batch_losses.append(loss_i)
        losses["train"].append(sum(batch_losses) / len(batch_losses))

        # Val epoch
        batch_losses = []
        for batch in zip(*get_batches(val_data, batch_size), strict=True):
            loss_i = loss_fn(params, static, *batch)
            key, subkey = jr.split(key)
            loss_i = loss_fn(params, static, *batch, key=subkey)
            batch_losses.append(loss_i)
        losses["val"].append(sum(batch_losses) / len(batch_losses))

40 changes: 28 additions & 12 deletions flowjax/train/losses.py
@@ -1,13 +1,18 @@
"""Common loss functions for training normalizing flows.

The loss functions are callables, with the first two arguments being the partitioned
distribution (see ``equinox.partition``).
In order to be compatible with ``fit_to_data``, the loss function arguments must match
``(params, static, x, condition, key)``, where ``params`` and ``static`` are the
partitioned model (see ``equinox.partition``).

For ``fit_to_variational_target``, the loss function signature must match
``(params, static, key)``.
"""

from collections.abc import Callable

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from jax import vmap
from jax.lax import stop_gradient
from jax.scipy.special import logsumexp
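The updated docstring fixes the loss interface expected by `fit_to_data` to `(params, static, x, condition, key)`. A hedged sketch of a user-defined loss following that convention (the class name and weighting scheme are illustrative, not part of flowjax; flowjax's own losses additionally call `unwrap` on the recombined model):

```python
import equinox as eqx


class WeightedMLELoss:
    """Illustrative only: negative log likelihood with per-example weights."""

    def __init__(self, weights):
        self.weights = weights

    def __call__(self, params, static, x, condition=None, key=None):
        dist = eqx.combine(params, static)  # recombine the partitioned model
        return -(self.weights * dist.log_prob(x, condition)).mean()
```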
@@ -30,8 +35,9 @@ def __call__(
        static: AbstractDistribution,
        x: Array,
        condition: Array | None = None,
        key: PRNGKeyArray | None = None,
    ) -> Float[Array, ""]:
        """Compute the loss."""
        """Compute the loss. Key is ignored (for consistency of API)."""
        dist = unwrap(eqx.combine(params, static))
        return -dist.log_prob(x, condition).mean()

@@ -52,7 +58,7 @@ class ContrastiveLoss:
        prior: The prior distribution over x (the target
            variable).
        n_contrastive: The number of contrastive samples/atoms to use when
            computing the loss.
            computing the loss. Must be less than ``batch_size``.

    References:
        - https://arxiv.org/abs/1905.07488
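For context, the referenced paper (arXiv:1905.07488, the APT/SNPE-C method) frames this as a classification problem: the true `x_i` must be picked out from a set of contrastive atoms drawn from the rest of the batch, using `log q(x | condition) - log p(x)` as the logit. A small array-only sketch of the per-example objective (names here are illustrative, not flowjax API):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def atomic_loss(positive_logit, contrastive_logits):
    # negative log-softmax of the positive logit over {positive} plus the atoms
    return -(positive_logit - logsumexp(jnp.append(contrastive_logits, positive_logit)))


print(atomic_loss(jnp.asarray(2.0), jnp.asarray([0.5, -1.0, 0.0])))
```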
@@ -69,30 +75,40 @@ def __call__(
        params: AbstractDistribution,
        static: AbstractDistribution,
        x: Float[Array, "..."],
        condition: Array | None = None,
        condition: Array | None,
        key: PRNGKeyArray,
    ) -> Float[Array, ""]:
        """Compute the loss."""
        if x.shape[0] <= self.n_contrastive:
            raise ValueError(
                f"Number of contrastive samples {self.n_contrastive} must be less than "
                f"the size of x {x.shape}.",
            )

        dist = unwrap(eqx.combine(params, static))

        def single_x_loss(x_i, condition_i, idx):
        def single_x_loss(x_i, condition_i, contrastive_idxs):
            positive_logit = dist.log_prob(x_i, condition_i) - self.prior.log_prob(x_i)
            contrastive = jnp.delete(x, idx, assume_unique_indices=True, axis=0)[
                : self.n_contrastive
            ]
            contrastive = x[contrastive_idxs]
            contrastive_logits = dist.log_prob(
                contrastive, condition_i
            ) - self.prior.log_prob(contrastive)
            normalizer = logsumexp(jnp.append(contrastive_logits, positive_logit))
            return -(positive_logit - normalizer)

        return eqx.filter_vmap(single_x_loss)(
            x, condition, jnp.arange(x.shape[0], dtype=int)
        ).mean()
        contrastive_idxs = _get_contrastive_idxs(key, x.shape[0], self.n_contrastive)
        return eqx.filter_vmap(single_x_loss)(x, condition, contrastive_idxs).mean()


def _get_contrastive_idxs(key: PRNGKeyArray, batch_size: int, n_contrastive: int):

    @eqx.filter_vmap
    def _get_idxs(key, idx, batch_size, n_contrastive):
        choices = jnp.delete(jnp.arange(batch_size), idx, assume_unique_indices=True)
        return jr.choice(key, choices, (n_contrastive,), replace=False)

    keys = jr.split(key, batch_size)
    return _get_idxs(keys, jnp.arange(batch_size), batch_size, n_contrastive)


class ElboLoss:
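A quick usage sketch of the new `_get_contrastive_idxs` helper above (a private function, so this is for illustration only): each row indexes `n_contrastive` other batch elements, never the element's own position.

```python
import jax.random as jr

from flowjax.train.losses import _get_contrastive_idxs

idxs = _get_contrastive_idxs(jr.PRNGKey(0), batch_size=6, n_contrastive=3)
print(idxs.shape)  # (6, 3): one row of contrastive atoms per batch element
for i, row in enumerate(idxs):
    assert i not in row  # an element is never its own contrastive atom
```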
9 changes: 8 additions & 1 deletion flowjax/train/train_utils.py
@@ -19,6 +19,7 @@ def step(
    optimizer: optax.GradientTransformation,
    opt_state: PyTree,
    loss_fn: Callable[[PyTree, PyTree], Scalar],
    **kwargs,
):
    """Carry out a training step.

@@ -30,11 +31,17 @@
        opt_state: Optimizer state.
        loss_fn: The loss function. This should take params and static as the first two
            arguments.
        **kwargs: Key word arguments passed to the loss function.

    Returns:
        tuple: (params, opt_state, loss_val)
    """
    loss_val, grads = eqx.filter_value_and_grad(loss_fn)(params, static, *args)
    loss_val, grads = eqx.filter_value_and_grad(loss_fn)(
        params,
        static,
        *args,
        **kwargs,
    )
    updates, opt_state = optimizer.update(grads, opt_state, params=params)
    params = eqx.apply_updates(params, updates)
    return params, opt_state, loss_val
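Forwarding `**kwargs` here is what lets `fit_to_data` pass `key=subkey` without changing `step`'s positional interface. A toy sketch of the same pattern (the loss and parameters below are made up for illustration, not flowjax code):

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr


def toy_step(params, static, *args, loss_fn, **kwargs):
    # kwargs (e.g. key=...) are forwarded untouched to the loss function
    loss_val, grads = eqx.filter_value_and_grad(loss_fn)(params, static, *args, **kwargs)
    return loss_val, grads


def toy_loss(params, static, x, key=None):
    scale = 1.0 if key is None else jr.uniform(key)  # pretend stochasticity
    return scale * jnp.sum(params["w"] * x)


loss, grads = toy_step({"w": jnp.ones(3)}, None, jnp.arange(3.0), loss_fn=toy_loss, key=jr.PRNGKey(0))
print(loss, grads)
```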
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ license = { file = "LICENSE" }
name = "flowjax"
readme = "README.md"
requires-python = ">=3.10"
version = "12.5.0"
version = "13.0.0"

[project.urls]
repository = "https://github.com/danielward27/flowjax"
17 changes: 17 additions & 0 deletions tests/test_train/test_losses.py
@@ -0,0 +1,17 @@
import jax.numpy as jnp
import jax.random as jr

from flowjax.train.losses import _get_contrastive_idxs


def test_get_contrastive_idxs():
    key = jr.PRNGKey(0)
    batch_size = 5

    for _ in range(5):
        key, subkey = jr.split(key)
        idxs = _get_contrastive_idxs(subkey, batch_size=batch_size, n_contrastive=4)
        for i, row in enumerate(idxs):
            assert i not in row

    assert jnp.all(idxs < batch_size)
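A natural follow-up check (not part of this PR) would assert that each row contains no repeated indices, since `jr.choice` is called with `replace=False`; it could reuse this module's imports:

```python
def test_contrastive_idxs_unique_within_row():
    idxs = _get_contrastive_idxs(jr.PRNGKey(1), batch_size=5, n_contrastive=4)
    for row in idxs:
        assert len(jnp.unique(row)) == row.shape[0]  # sampled without replacement
```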