Merge pull request #170 from danielward27/planar_inverse
Planar inverse with leaky relu
danielward27 committed Aug 8, 2024
2 parents 91b1c39 + fdb2d5e commit a4bcab8
Showing 3 changed files with 98 additions and 22 deletions.
102 changes: 81 additions & 21 deletions flowjax/bijections/planar.py
@@ -4,30 +4,40 @@
"""

from collections.abc import Callable
from typing import ClassVar
from functools import partial
from typing import ClassVar, Literal

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from jax.nn import softplus
from jax import nn
from jax.numpy.linalg import norm
from jaxtyping import Array, PRNGKeyArray
from jaxtyping import Array, Float, PRNGKeyArray

from flowjax.bijections.bijection import AbstractBijection


class Planar(AbstractBijection):
r"""Planar bijection as used by https://arxiv.org/pdf/1505.05770.pdf.
Uses the transformation :math:`y + u \cdot \text{tanh}(w \cdot x + b)`, where
:math:`u \in \mathbb{R}^D, \ w \in \mathbb{R}^D` and :math:`b \in \mathbb{R}`. In
the unconditional case, :math:`w`, :math:`u` and :math:`b` are learned directly.
In the conditional case they are parameterised by an MLP.
Uses the transformation
.. math::
\boldsymbol{y}=\boldsymbol{x} +
\boldsymbol{u} \cdot \text{tanh}(\boldsymbol{w}^T \boldsymbol{x} + b)
where :math:`\boldsymbol{u} \in \mathbb{R}^D, \ \boldsymbol{w} \in \mathbb{R}^D`
and :math:`b \in \mathbb{R}`. In the unconditional case, the (unbounded) parameters
are learned directly. In the conditional case they are parameterised by an MLP.
Args:
key: Jax random seed.
dim: Dimension of the bijection.
cond_dim: Dimension of extra conditioning variables. Defaults to None.
negative_slope: A positive float. If provided, then a leaky relu activation
(with the corresponding negative slope) is used instead of tanh. This also
provides the advantage that the bijection can be inverted analytically.
**mlp_kwargs: Keyword arguments (excluding in_size and out_size) passed to
the MLP (equinox.nn.MLP). Ignored when cond_dim is None.
"""
@@ -36,13 +46,15 @@ class Planar(AbstractBijection):
cond_shape: tuple[int, ...] | None
conditioner: Callable | None
params: Array | None
negative_slope: float | None

def __init__(
self,
key: PRNGKeyArray,
*,
dim: int,
cond_dim: int | None = None,
negative_slope: float | None = None,
**mlp_kwargs,
):
self.shape = (dim,)
@@ -56,6 +68,8 @@ def __init__(
self.conditioner = eqx.nn.MLP(cond_dim, 2 * dim + 1, **mlp_kwargs, key=key)
self.cond_shape = (cond_dim,)

self.negative_slope = negative_slope

def transform(self, x, condition=None):
return self.get_planar(condition).transform(x)

@@ -77,36 +91,59 @@ def get_planar(self, condition=None):
dim = self.shape[0]
assert params is not None
w, u, bias = params[:dim], params[dim : 2 * dim], params[-1]
return _UnconditionalPlanar(w, u, bias)
return _UnconditionalPlanar(w, u, bias, self.negative_slope)


class _UnconditionalPlanar(AbstractBijection):
"""Unconditional planar bijection, used in Planar.
Note act_scale (u in the paper) is unconstrained and the constraint to ensure
invertibility is applied in the ``get_act_scale``.
invertibility is applied in ``get_act_scale``.
"""

shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
weight: Array
_act_scale: Array
bias: Array
activation: Literal["tanh"] | Literal["leaky_relu"]
activation_fn: Callable
negative_slope: float | None

def __init__(self, weight, act_scale, bias):
def __init__(
self,
weight: Float[Array, " dim"],
act_scale: Float[Array, " dim"],
bias: Float[Array, " "],
negative_slope: float | None = None,
):
self.weight = weight
self._act_scale = act_scale
self.bias = bias
self.shape = weight.shape
self.negative_slope = negative_slope
self._act_scale = act_scale

if negative_slope is None:
self.activation = "tanh"
self.activation_fn = jnp.tanh
else:
if negative_slope <= 0:
raise ValueError("The negative slope value should be >0.")
self.activation = "leaky_relu"
self.activation_fn = partial(nn.leaky_relu, negative_slope=negative_slope)

def transform(self, x, condition=None):
return x + self.get_act_scale() * jnp.tanh(self.weight @ x + self.bias)
u = self.get_act_scale()
return x + u * self.activation_fn(self.weight @ x + self.bias)

def transform_and_log_det(self, x, condition=None):
u = self.get_act_scale()
act = jnp.tanh(x @ self.weight + self.bias)
act = self.activation_fn(x @ self.weight + self.bias)
y = x + u * act
psi = (1 - act**2) * self.weight
if self.activation == "leaky_relu":
psi = jnp.where(act < 0, self.negative_slope, 1) * self.weight
else:
psi = (1 - act**2) * self.weight
log_det = jnp.log(jnp.abs(1 + u @ psi))
return y, log_det

@@ -116,15 +153,38 @@ def get_act_scale(self):
See appendix A1 in https://arxiv.org/pdf/1505.05770.pdf.
"""
wtu = self._act_scale @ self.weight
m_wtu = -1 + jnp.log(1 + softplus(wtu))
m_wtu = -1 + jnp.log(1 + nn.softplus(wtu))
return self._act_scale + (m_wtu - wtu) * self.weight / norm(self.weight) ** 2

def inverse(self, y, condition=None):
raise NotImplementedError(
"The inverse planar transformation is not implemented.",
)
if self.activation != "leaky_relu":
raise NotImplementedError(
"The inverse planar transformation is only implemented with the leaky "
"relu activation function.",
)
return self.inverse_and_log_det(y, condition)[0]

def inverse_and_log_det(self, y, condition=None):
raise NotImplementedError(
"The inverse planar transformation is not implemented.",
)
if self.activation != "leaky_relu":
raise NotImplementedError(
"The inverse planar transformation is only implemented with the leaky "
"relu activation function.",
)
# Expanded explanation, as the inverse is not given in the original paper.
# The derivation steps for the inversion are:
# 1. Let z = w^Tx+b
# 2. We want x=y-uσ(z), where σ is the leaky relu function.
# 3. Sub x=y-uσ(z) into z = w^Tx+b,
# 4. Solve for z, which gives z = (w^Ty+b)/(1+w^Tus), where s is the slope
# σ'(z), i.e. s=1 if z>=0 and s=negative_slope otherwise. To find the
# slope, it is sufficient to check the sign of the numerator w^Ty+b, rather
# than z, as the denominator is constrained to be positive.
# 5. Compute inverse using x=y-uσ(z)

numerator = self.weight @ y + self.bias
relu_slope = jnp.where(numerator < 0, self.negative_slope, 1)
us = self.get_act_scale() * relu_slope
denominator = 1 + self.weight @ us
log_det = -jnp.log(jnp.abs(1 + us @ self.weight))
x = y - us * (numerator / denominator)
return x, log_det
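
A minimal usage sketch (not part of this commit), assuming Planar is exported from
flowjax.bijections: with negative_slope set, the layer uses leaky relu, so the analytic
inverse added above is available, while the default tanh variant still raises
NotImplementedError on inverse.

import jax.numpy as jnp
import jax.random as jr

from flowjax.bijections import Planar

key = jr.PRNGKey(0)
# Unconditional planar layer with a leaky relu activation (negative slope 0.1).
bijection = Planar(key, dim=3, negative_slope=0.1)
x = jr.normal(jr.PRNGKey(1), (3,))
y = bijection.transform(x)
x_reconstructed = bijection.inverse(y)  # analytic inverse, only with leaky relu
assert jnp.allclose(x, x_reconstructed, atol=1e-5)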
5 changes: 5 additions & 0 deletions flowjax/flows.py
@@ -226,6 +226,7 @@ def planar_flow(
cond_dim: int | None = None,
flow_layers: int = 8,
invert: bool = True,
negative_slope: float | None = None,
**mlp_kwargs,
) -> Transformed:
"""Planar flow as introduced in https://arxiv.org/pdf/1505.05770.pdf.
@@ -241,6 +242,9 @@
invert: Whether to invert the bijection. Broadly, True will prioritise faster
`inverse` methods, leading to faster `log_prob`; False will prioritise
faster `transform` methods, leading to faster `sample`. Defaults to True.
negative_slope: A positive float. If provided, then a leaky relu activation
(with the corresponding negative slope) is used instead of tanh. This also
provides the advantage that the bijection can be inverted analytically.
**mlp_kwargs: Keyword arguments (excluding in_size and out_size) passed to
the MLP (equinox.nn.MLP). Ignored when cond_dim is None.
"""
@@ -251,6 +255,7 @@ def make_layer(key): # Planar layer + permutation
bij_key,
dim=base_dist.shape[-1],
cond_dim=cond_dim,
negative_slope=negative_slope,
**mlp_kwargs,
)
return _add_default_permute(bijection, base_dist.shape[-1], perm_key)
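
A hedged sketch (not in this diff) of the planar_flow entry point with the new
negative_slope argument; the key-first call and the Normal base distribution are
assumptions based on the surrounding flowjax API, not shown in this hunk.

import jax.numpy as jnp
import jax.random as jr

from flowjax.distributions import Normal
from flowjax.flows import planar_flow

flow = planar_flow(
    jr.PRNGKey(0),
    base_dist=Normal(jnp.zeros(2)),
    flow_layers=4,
    negative_slope=0.1,  # leaky relu planar layers
)
samples = flow.sample(jr.PRNGKey(1), (100,))
log_probs = flow.log_prob(samples)  # both directions work via the analytic inverse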
13 changes: 12 additions & 1 deletion tests/test_bijections/test_bijections.py
@@ -35,6 +35,7 @@
TriangularAffine,
Vmap,
)
from flowjax.bijections.planar import _UnconditionalPlanar

DIM = 3
COND_DIM = 2
@@ -171,6 +172,16 @@
[Affine(jr.uniform(k, (1, 2, 3))) for k in jr.split(KEY, 3)],
axis=-1,
),
"_UnconditionalPlanar (leaky_relu +ve bias)": lambda: _UnconditionalPlanar(
*jnp.split(jr.normal(KEY, (8,)), 2),
bias=jnp.array(100.0), # leads to evaluation in +ve relu portion
negative_slope=0.1,
),
"_UnconditionalPlanar (leaky_relu -ve bias)": lambda: _UnconditionalPlanar(
*jnp.split(jr.normal(KEY, (8,)), 2),
bias=-jnp.array(100.0), # leads to evaluation in -ve relu portion
negative_slope=0.1,
),
"Planar": lambda: Planar(
KEY,
dim=DIM,
Expand Down Expand Up @@ -209,7 +220,7 @@ def test_transform_inverse(bijection_name):
y = bijection.transform(x, cond)
try:
x_reconstructed = bijection.inverse(y, cond)
assert x == pytest.approx(x_reconstructed, abs=1e-4)
assert x_reconstructed == pytest.approx(x, abs=1e-4)
except NotImplementedError:
pass
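
As a sanity check outside this test suite, a sketch comparing the analytic log
determinant of the leaky relu layer against the Jacobian computed by JAX autodiff;
the constructor arguments mirror the _UnconditionalPlanar entries added above, and the
large negative bias keeps the evaluation in the negative relu portion.

import jax
import jax.numpy as jnp
import jax.random as jr

from flowjax.bijections.planar import _UnconditionalPlanar

weight, act_scale = jnp.split(jr.normal(jr.PRNGKey(0), (8,)), 2)
bijection = _UnconditionalPlanar(
    weight, act_scale, bias=jnp.array(-100.0), negative_slope=0.1
)
x = jr.normal(jr.PRNGKey(1), (4,))
y, log_det = bijection.transform_and_log_det(x)
jacobian = jax.jacobian(bijection.transform)(x)
# The rank-one Jacobian determinant 1 + u @ psi should match autodiff.
assert jnp.allclose(log_det, jnp.linalg.slogdet(jacobian)[1], atol=1e-5)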

