Skip to content

Commit

Permalink
Merge pull request #164 from danielward27/dev
Browse files Browse the repository at this point in the history
Fix bug in ContrastiveLoss normalization and improve numpyro transformed wrapper
  • Loading branch information
danielward27 committed Jun 6, 2024
2 parents b153b38 + b3ffba7 commit c067413
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 30 deletions.
15 changes: 7 additions & 8 deletions docs/examples/snpe.ipynb

Large diffs are not rendered by default.

50 changes: 49 additions & 1 deletion flowjax/experimental/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,59 @@
raise

from jaxtyping import PyTree
from numpyro.distributions import TransformedDistribution
from numpyro.distributions.constraints import (
_IndependentConstraint,
_Real,
)
from numpyro.distributions.transforms import IndependentTransform, biject_to
from numpyro.distributions.util import sum_rightmost

from flowjax.bijections import Invert


class _BetterTransformedDistribution(TransformedDistribution):
    """``TransformedDistribution`` whose ``log_prob`` fuses inverse and log det.

    In numpyro, the log_prob method separately computes the inverse and the log
    jacobian of the forward transformation. This becomes inefficient (or causes
    errors) when the inverse computation and the forward log det share
    computations. This class avoids this for flowjax bijections.
    """

    def __init__(self, *args, **kwargs):
        # Pass-through constructor: every argument is forwarded unchanged.
        super().__init__(*args, **kwargs)

    def log_prob(self, value, intermediates=None):
        """Evaluate the log density of ``value`` under the transformed distribution.

        Args:
            value: Point in the space of the transformed variable.
            intermediates: Optional sequence with one entry per transform. For
                each entry, element ``[0]`` is used as the input to that
                transform and element ``[1]`` is forwarded to
                ``log_abs_det_jacobian``. When ``None``, the inputs are
                recovered by inverting the transforms.

        Returns:
            The base distribution log density at the fully inverted value, minus
            the accumulated log determinants of the transforms.

        Raises:
            ValueError: If ``intermediates`` is given but its length differs
                from the number of transforms.
        """
        # Without this check a too-long ``intermediates`` would be silently
        # mis-indexed by the negative indices below, and a too-short one would
        # surface as a bare IndexError.
        if intermediates is not None and len(intermediates) != len(self.transforms):
            raise ValueError(
                f"Intermediates array has length = {len(intermediates)}. "
                f"Expected = {len(self.transforms)}."
            )

        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value

        # Walk the transforms from the outermost back towards the base distribution.
        for i, transform in enumerate(reversed(self.transforms)):
            if isinstance(transform, _BijectionToNumpyro) and intermediates is None:
                # Compute inv and log det in one pass by wrapping the inverted
                # bijection, with domain and codomain taken from transform.inv.
                inv_transform = _BijectionToNumpyro(
                    Invert(transform.bijection),
                    transform.condition,
                    domain=transform.inv.domain,
                    codomain=transform.inv.codomain,
                )
                x, t_log_det = inv_transform.call_with_intermediates(y)
                # Negated so it plays the role of the forward log det used in
                # the other branch (presumably the returned value is the log
                # det of the y -> x map).
                t_log_det = -t_log_det
            else:
                if intermediates is None:
                    x = transform.inv(y)
                    t_inter = None
                else:
                    x = intermediates[-i - 1][0]
                    t_inter = intermediates[-i - 1][1]
                t_log_det = transform.log_abs_det_jacobian(x, y, t_inter)

            batch_ndim = event_dim - transform.codomain.event_dim
            log_prob = log_prob - sum_rightmost(t_log_det, batch_ndim)
            event_dim = transform.domain.event_dim + batch_ndim
            y = x

        return log_prob + sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )


class _RealNdim(_IndependentConstraint):
Expand Down Expand Up @@ -158,7 +206,7 @@ def _transformed_to_numpyro(dist, condition=None):
base_dist = _DistributionToNumpyro(dist.base_dist).expand(batch_shape)

transform = _BijectionToNumpyro(dist.bijection, condition)
return numpyro.distributions.TransformedDistribution(base_dist, transform)
return _BetterTransformedDistribution(base_dist, transform)


def _get_batch_shape(condition, cond_shape):
Expand Down
37 changes: 18 additions & 19 deletions flowjax/train/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,28 +72,27 @@ def __call__(
condition: Array | None = None,
) -> Float[Array, ""]:
"""Compute the loss."""
dist = unwrap(eqx.combine(params, static))
contrastive = self._get_contrastive(x)
joint_log_odds = dist.log_prob(x, condition) - self.prior.log_prob(x)
contrastive_log_odds = dist.log_prob(
contrastive,
condition,
) - self.prior.log_prob(contrastive)
contrastive_log_odds = jnp.clip(
contrastive_log_odds, -5
) # TODO Clip for stability - this maybe should reconsidered
return -(joint_log_odds - logsumexp(contrastive_log_odds, axis=0)).mean()

def _get_contrastive(self, theta):
if theta.shape[0] <= self.n_contrastive:
if x.shape[0] <= self.n_contrastive:
raise ValueError(
f"Number of contrastive samples {self.n_contrastive} must be less than "
f"the size of theta {theta.shape}.",
f"the size of x {x.shape}.",
)
# Rolling window over theta batch to create contrastive samples.
idx = jnp.arange(len(theta))[:, None] + jnp.arange(self.n_contrastive)[None, :]
contrastive = jnp.roll(theta[idx], -1, axis=0) # Ensure mismatch with condition
return jnp.swapaxes(contrastive, 0, 1) # (contrastive, batch_size, dim)
dist = unwrap(eqx.combine(params, static))

def single_x_loss(x_i, condition_i, idx):
positive_logit = dist.log_prob(x_i, condition_i) - self.prior.log_prob(x_i)
contrastive = jnp.delete(x, idx, assume_unique_indices=True, axis=0)[
: self.n_contrastive
]
contrastive_logits = dist.log_prob(
contrastive, condition_i
) - self.prior.log_prob(contrastive)
normalizer = logsumexp(jnp.append(contrastive_logits, positive_logit))
return -(positive_logit - normalizer)

return eqx.filter_vmap(single_x_loss)(
x, condition, jnp.arange(x.shape[0], dtype=int)
).mean()


class ElboLoss:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ license = { file = "LICENSE" }
name = "flowjax"
readme = "README.md"
requires-python = ">=3.10"
version = "12.2.0"
version = "12.3.0"

[project.urls]
repository = "https://github.com/danielward27/flowjax"
Expand Down
38 changes: 37 additions & 1 deletion tests/test_experimental/test_numpyro.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import partial
from typing import ClassVar

import equinox as eqx
import jax
Expand All @@ -15,7 +16,7 @@
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.optim import Adam

from flowjax.bijections import AdditiveCondition, Affine
from flowjax.bijections import AbstractBijection, AdditiveCondition, Affine, Invert
from flowjax.distributions import (
LogNormal,
Normal,
Expand Down Expand Up @@ -391,3 +392,38 @@ def model():
assert "x_base" in trace.keys()
expected_x = log_norm.bijection.transform(trace["x_base"]["value"])
assert pytest.approx(expected_x) == trace["x"]["value"]


class _ForwardOnly(AbstractBijection):
    """Identity bijection that only implements the forward direction.

    Test helper: both inverse methods raise ``NotImplementedError``, so any code
    path that touches the inverse fails loudly.
    """

    shape: tuple[int, ...] = ()
    cond_shape: ClassVar[None] = None

    def transform(self, x, condition=None):
        """Return ``x`` unchanged."""
        return x

    def transform_and_log_det(self, x, condition=None):
        """Return ``x`` unchanged together with a scalar zero log determinant."""
        log_det = jnp.zeros(())
        return x, log_det

    def inverse(self, y, condition=None):
        """Always raise; the inverse is deliberately unavailable."""
        raise NotImplementedError

    def inverse_and_log_det(self, y, condition=None):
        """Always raise; the inverse is deliberately unavailable."""
        raise NotImplementedError


def test_sampling_forward_only():
    """Sampling should run when the bijection implements only the forward pass."""
    forward_only = Transformed(StandardNormal(), _ForwardOnly())
    numpyro_dist = distribution_to_numpyro(forward_only)
    # Smoke test: passes as long as no exception is raised.
    numpyro_dist.sample(jr.PRNGKey(0))


def test_log_prob_inverse_only():
    """log_prob should run when only the inverse direction is implemented."""
    # Inverting the forward-only bijection leaves just its inverse methods usable.
    inverse_only = Transformed(StandardNormal(), Invert(_ForwardOnly()))
    numpyro_dist = distribution_to_numpyro(inverse_only)
    # Smoke test: passes as long as no exception is raised.
    numpyro_dist.log_prob(0)

0 comments on commit c067413

Please sign in to comment.