diff --git a/flowjax/bijections/planar.py b/flowjax/bijections/planar.py
index f5f7bec1..f6c39acd 100644
--- a/flowjax/bijections/planar.py
+++ b/flowjax/bijections/planar.py
@@ -4,14 +4,15 @@
 """
 from collections.abc import Callable
-from typing import ClassVar
+from functools import partial
+from typing import ClassVar, Literal

 import equinox as eqx
 import jax.numpy as jnp
 import jax.random as jr
-from jax.nn import softplus
+from jax import nn
 from jax.numpy.linalg import norm
-from jaxtyping import Array, PRNGKeyArray
+from jaxtyping import Array, Float, PRNGKeyArray

 from flowjax.bijections.bijection import AbstractBijection

@@ -19,15 +20,24 @@
 class Planar(AbstractBijection):
     r"""Planar bijection as used by https://arxiv.org/pdf/1505.05770.pdf.

-    Uses the transformation :math:`y + u \cdot \text{tanh}(w \cdot x + b)`, where
-    :math:`u \in \mathbb{R}^D, \ w \in \mathbb{R}^D` and :math:`b \in \mathbb{R}`. In
-    the unconditional case, :math:`w`, :math:`u` and :math:`b` are learned directly.
-    In the conditional case they are parameterised by an MLP.
+    Uses the transformation
+
+    .. math::
+
+        \boldsymbol{y}=\boldsymbol{x} +
+            \boldsymbol{u} \cdot \text{tanh}(\boldsymbol{w}^T \boldsymbol{x} + b)
+
+    where :math:`\boldsymbol{u} \in \mathbb{R}^D, \ \boldsymbol{w} \in \mathbb{R}^D`
+    and :math:`b \in \mathbb{R}`. In the unconditional case, the (unbounded) parameters
+    are learned directly. In the conditional case they are parameterised by an MLP.

     Args:
         key: Jax random seed.
         dim: Dimension of the bijection.
         cond_dim: Dimension of extra conditioning variables. Defaults to None.
+        negative_slope: A positive float. If provided, then a leaky relu activation
+            (with the corresponding negative slope) is used instead of tanh. This also
+            provides the advantage that the bijection can be inverted analytically.
         **mlp_kwargs: Keyword arguments (excluding in_size and out_size) passed to
             the MLP (equinox.nn.MLP). Ignored when cond_dim is None.
     """
@@ -36,6 +46,7 @@ class Planar(AbstractBijection):
     cond_shape: tuple[int, ...] | None
     conditioner: Callable | None
     params: Array | None
+    negative_slope: float | None

     def __init__(
         self,
@@ -43,6 +54,7 @@ def __init__(
         *,
         dim: int,
         cond_dim: int | None = None,
+        negative_slope: float | None = None,
         **mlp_kwargs,
     ):
         self.shape = (dim,)
@@ -56,6 +68,8 @@ def __init__(
             self.conditioner = eqx.nn.MLP(cond_dim, 2 * dim + 1, **mlp_kwargs, key=key)
             self.cond_shape = (cond_dim,)

+        self.negative_slope = negative_slope
+
     def transform(self, x, condition=None):
         return self.get_planar(condition).transform(x)

@@ -77,14 +91,14 @@ def get_planar(self, condition=None):
         dim = self.shape[0]
         assert params is not None
         w, u, bias = params[:dim], params[dim : 2 * dim], params[-1]
-        return _UnconditionalPlanar(w, u, bias)
+        return _UnconditionalPlanar(w, u, bias, self.negative_slope)


 class _UnconditionalPlanar(AbstractBijection):
     """Unconditional planar bijection, used in Planar.

     Note act_scale (u in the paper) is unconstrained and the constraint to ensure
-    invertiblitiy is applied in the ``get_act_scale``.
+    invertibility is applied in ``get_act_scale``.
     """

     shape: tuple[int, ...]
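For illustration only (not part of the patch): a minimal sketch of how the new ``negative_slope`` argument is intended to be used at the bijection level, assuming ``Planar`` is exported from ``flowjax.bijections`` as in the existing tests; the key, dimension and input values are placeholders.

    import jax.numpy as jnp
    import jax.random as jr

    from flowjax.bijections import Planar

    key = jr.PRNGKey(0)
    tanh_planar = Planar(key, dim=3)  # default tanh activation, no analytic inverse
    leaky_planar = Planar(key, dim=3, negative_slope=0.1)  # leaky relu, invertible

    x = jnp.arange(3.0)
    y, log_det = leaky_planar.transform_and_log_det(x)
    x_again = leaky_planar.inverse(y)  # only available for the leaky relu variant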
@@ -92,21 +106,44 @@ class _UnconditionalPlanar(AbstractBijection):
     weight: Array
     _act_scale: Array
     bias: Array
+    activation: Literal["tanh", "leaky_relu"]
+    activation_fn: Callable
+    negative_slope: float | None

-    def __init__(self, weight, act_scale, bias):
+    def __init__(
+        self,
+        weight: Float[Array, " dim"],
+        act_scale: Float[Array, " dim"],
+        bias: Float[Array, " "],
+        negative_slope: float | None = None,
+    ):
         self.weight = weight
-        self._act_scale = act_scale
         self.bias = bias
         self.shape = weight.shape
+        self.negative_slope = negative_slope
+        self._act_scale = act_scale
+
+        if negative_slope is None:
+            self.activation = "tanh"
+            self.activation_fn = jnp.tanh
+        else:
+            if negative_slope <= 0:
+                raise ValueError("The negative slope value should be >0.")
+            self.activation = "leaky_relu"
+            self.activation_fn = partial(nn.leaky_relu, negative_slope=negative_slope)

     def transform(self, x, condition=None):
-        return x + self.get_act_scale() * jnp.tanh(self.weight @ x + self.bias)
+        u = self.get_act_scale()
+        return x + u * self.activation_fn(self.weight @ x + self.bias)

     def transform_and_log_det(self, x, condition=None):
         u = self.get_act_scale()
-        act = jnp.tanh(x @ self.weight + self.bias)
+        act = self.activation_fn(x @ self.weight + self.bias)
         y = x + u * act
-        psi = (1 - act**2) * self.weight
+        if self.activation == "leaky_relu":
+            psi = jnp.where(act < 0, self.negative_slope, 1) * self.weight
+        else:
+            psi = (1 - act**2) * self.weight
         log_det = jnp.log(jnp.abs(1 + u @ psi))
         return y, log_det

@@ -116,15 +153,38 @@ def get_act_scale(self):
         See appendix A1 in https://arxiv.org/pdf/1505.05770.pdf.
         """
         wtu = self._act_scale @ self.weight
-        m_wtu = -1 + jnp.log(1 + softplus(wtu))
+        m_wtu = -1 + jnp.log(1 + nn.softplus(wtu))
         return self._act_scale + (m_wtu - wtu) * self.weight / norm(self.weight) ** 2

     def inverse(self, y, condition=None):
-        raise NotImplementedError(
-            "The inverse planar transformation is not implemented.",
-        )
+        if self.activation != "leaky_relu":
+            raise NotImplementedError(
+                "The inverse planar transformation is only implemented with the leaky "
+                "relu activation function.",
+            )
+        return self.inverse_and_log_det(y, condition)[0]

     def inverse_and_log_det(self, y, condition=None):
-        raise NotImplementedError(
-            "The inverse planar transformation is not implemented.",
-        )
+        if self.activation != "leaky_relu":
+            raise NotImplementedError(
+                "The inverse planar transformation is only implemented with the leaky "
+                "relu activation function.",
+            )
+        # Expanded explanation, as the inverse is not given in the original paper.
+        # The derivation steps for the inversion are:
+        # 1. Let z = w^Tx + b.
+        # 2. We want x = y - uσ(z), where σ is the leaky relu function.
+        # 3. Substitute x = y - uσ(z) into z = w^Tx + b.
+        # 4. Solve for z, which gives z = (w^Ty + b)/(1 + w^Tus), where s is the slope
+        #    σ'(z), i.e. s=1 if z>=0 and s=negative_slope otherwise. To find the
+        #    slope, it is sufficient to check the sign of the numerator w^Ty + b,
+        #    rather than z, as the denominator is constrained to be positive.
+        # 5. Compute the inverse using x = y - uσ(z).
+
+        numerator = self.weight @ y + self.bias
+        relu_slope = jnp.where(numerator < 0, self.negative_slope, 1)
+        us = self.get_act_scale() * relu_slope
+        denominator = 1 + self.weight @ us
+        log_det = -jnp.log(jnp.abs(1 + us @ self.weight))
+        x = y - us * (numerator / denominator)
+        return x, log_det
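To complement the derivation in the comments above, the following sketch (not part of the patch) round-trips the leaky relu variant of ``_UnconditionalPlanar`` and checks that the forward and inverse log-determinants agree up to sign; the parameter values are arbitrary placeholders.

    import jax.numpy as jnp
    import jax.random as jr

    from flowjax.bijections.planar import _UnconditionalPlanar

    w_key, u_key = jr.split(jr.PRNGKey(0))
    bijection = _UnconditionalPlanar(
        weight=jr.normal(w_key, (4,)),
        act_scale=jr.normal(u_key, (4,)),
        bias=jnp.array(0.5),
        negative_slope=0.1,
    )

    x = jnp.linspace(-1.0, 1.0, 4)
    y, forward_log_det = bijection.transform_and_log_det(x)
    x_again, inverse_log_det = bijection.inverse_and_log_det(y)

    # The slope s depends only on the sign of w^Ty + b (step 4 above), which matches
    # the sign of w^Tx + b, so the two log-determinants should be exact negatives.
    assert jnp.allclose(x_again, x, atol=1e-5)
    assert jnp.allclose(forward_log_det, -inverse_log_det, atol=1e-5)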
diff --git a/flowjax/flows.py b/flowjax/flows.py
index 615b4af0..2c8a968a 100644
--- a/flowjax/flows.py
+++ b/flowjax/flows.py
@@ -226,6 +226,7 @@ def planar_flow(
     cond_dim: int | None = None,
     flow_layers: int = 8,
     invert: bool = True,
+    negative_slope: float | None = None,
     **mlp_kwargs,
 ) -> Transformed:
     """Planar flow as introduced in https://arxiv.org/pdf/1505.05770.pdf.
@@ -241,6 +242,9 @@ def planar_flow(
         invert: Whether to invert the bijection. Broadly, True will prioritise a
             faster `inverse` methods, leading to faster `log_prob`, False will prioritise
             faster `transform` methods, leading to faster `sample`. Defaults to True.
+        negative_slope: A positive float. If provided, then a leaky relu activation
+            (with the corresponding negative slope) is used instead of tanh. This also
+            provides the advantage that the bijection can be inverted analytically.
         **mlp_kwargs: Keyword arguments (excluding in_size and out_size) passed to
             the MLP (equinox.nn.MLP). Ignored when cond_dim is None.
     """
@@ -251,6 +255,7 @@ def make_layer(key):  # Planar layer + permutation
             bij_key,
             dim=base_dist.shape[-1],
             cond_dim=cond_dim,
+            negative_slope=negative_slope,
             **mlp_kwargs,
         )
         return _add_default_permute(bijection, base_dist.shape[-1], perm_key)
diff --git a/tests/test_bijections/test_bijections.py b/tests/test_bijections/test_bijections.py
index 3077485d..887fa8a7 100644
--- a/tests/test_bijections/test_bijections.py
+++ b/tests/test_bijections/test_bijections.py
@@ -35,6 +35,7 @@
     TriangularAffine,
     Vmap,
 )
+from flowjax.bijections.planar import _UnconditionalPlanar

 DIM = 3
 COND_DIM = 2
@@ -171,6 +172,16 @@
         [Affine(jr.uniform(k, (1, 2, 3))) for k in jr.split(KEY, 3)],
         axis=-1,
     ),
+    "_UnconditionalPlanar (leaky_relu +ve bias)": lambda: _UnconditionalPlanar(
+        *jnp.split(jr.normal(KEY, (8,)), 2),
+        bias=jnp.array(100.0),  # leads to evaluation in +ve relu portion
+        negative_slope=0.1,
+    ),
+    "_UnconditionalPlanar (leaky_relu -ve bias)": lambda: _UnconditionalPlanar(
+        *jnp.split(jr.normal(KEY, (8,)), 2),
+        bias=jnp.array(-100.0),  # leads to evaluation in -ve relu portion
+        negative_slope=0.1,
+    ),
     "Planar": lambda: Planar(
         KEY,
         dim=DIM,
@@ -209,7 +220,7 @@ def test_transform_inverse(bijection_name):
     y = bijection.transform(x, cond)
     try:
         x_reconstructed = bijection.inverse(y, cond)
-        assert x == pytest.approx(x_reconstructed, abs=1e-4)
+        assert x_reconstructed == pytest.approx(x, abs=1e-4)
     except NotImplementedError:
         pass
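Finally, an end-to-end sketch (illustrative, not part of the diff) of the new flow-level argument, assuming the existing ``planar_flow`` signature with a keyword ``base_dist`` and the ``Normal`` distribution from ``flowjax.distributions``; sizes and keys are placeholders. With a leaky relu activation the planar layers are analytically invertible, so both ``sample`` and ``log_prob`` should be available irrespective of ``invert``.

    import jax.numpy as jnp
    import jax.random as jr

    from flowjax.distributions import Normal
    from flowjax.flows import planar_flow

    flow = planar_flow(
        jr.PRNGKey(0),
        base_dist=Normal(jnp.zeros(2)),
        flow_layers=4,
        negative_slope=0.1,  # use the leaky relu planar layers
    )
    samples = flow.sample(jr.PRNGKey(1), (100,))
    log_probs = flow.log_prob(samples)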