Skip to content

Commit

Permalink
Merge pull request #166 from danielward27/spline_interval
Browse files Browse the repository at this point in the history
Allow uneven interval in spline
  • Loading branch information
danielward27 committed Jul 11, 2024
2 parents 19c7355 + e431b6e commit bdae1ba
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 23 deletions.
25 changes: 14 additions & 11 deletions flowjax/bijections/rational_quadratic_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

def _real_to_increasing_on_interval(
arr: Float[Array, " dim"],
interval: float | int = 1,
interval: tuple[int | float, int | float],
softmax_adjust: float = 1e-2,
*,
pad_with_ends: bool = True,
Expand All @@ -35,10 +35,11 @@ def _real_to_increasing_on_interval(
widths = jax.nn.softmax(arr)
widths = (widths + softmax_adjust / widths.size) / (1 + softmax_adjust)
widths = widths.at[0].set(widths[0] / 2)
pos = 2 * interval * jnp.cumsum(widths) - interval
scale = interval[1] - interval[0]
pos = interval[0] + scale * jnp.cumsum(widths)

if pad_with_ends:
pos = jnp.pad(pos, pad_width=1, constant_values=(-interval, interval))
pos = jnp.pad(pos, pad_width=1, constant_values=interval)

return pos

Expand All @@ -48,7 +49,8 @@ class RationalQuadraticSpline(AbstractBijection):
Args:
knots: Number of knots.
interval: interval to transform, [-interval, interval].
interval: Interval to transform, if a scalar value, uses [-interval, interval],
if a tuple, uses [interval[0], interval[1]]
min_derivative: Minimum dervivative. Defaults to 1e-3.
softmax_adjust: Controls minimum bin width and height by rescaling softmax
output, e.g. 0=no adjustment, 1=average softmax output with evenly spaced
Expand All @@ -57,7 +59,7 @@ class RationalQuadraticSpline(AbstractBijection):
"""

knots: int
interval: float | int
interval: tuple[int | float, int | float]
softmax_adjust: float | int
min_derivative: float
x_pos: Array | wrappers.AbstractUnwrappable[Array]
Expand All @@ -70,11 +72,12 @@ def __init__(
self,
*,
knots: int,
interval: float | int,
interval: float | int | tuple[int | float, int | float],
min_derivative: float = 1e-3,
softmax_adjust: float | int = 1e-2,
):
self.knots = knots
interval = interval if isinstance(interval, tuple) else (-interval, interval)
self.interval = interval
self.softmax_adjust = softmax_adjust
self.min_derivative = min_derivative
Expand All @@ -96,7 +99,7 @@ def __init__(
def transform(self, x, condition=None):
# Following notation from the paper
x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives
in_bounds = jnp.logical_and(x >= -self.interval, x <= self.interval)
in_bounds = jnp.logical_and(x >= self.interval[0], x <= self.interval[1])
x_robust = jnp.where(in_bounds, x, 0) # To avoid nans
k = jnp.searchsorted(x_pos, x_robust) - 1 # k is bin number
xi = (x_robust - x_pos[k]) / (x_pos[k + 1] - x_pos[k])
Expand All @@ -107,7 +110,7 @@ def transform(self, x, condition=None):
y = yk + num / den # eq. 4

# avoid numerical precision issues transforming from in -> out of bounds
y = jnp.clip(y, -self.interval, self.interval)
y = jnp.clip(y, self.interval[0], self.interval[1])
return jnp.where(in_bounds, y, x)

def transform_and_log_det(self, x, condition=None):
Expand All @@ -118,7 +121,7 @@ def transform_and_log_det(self, x, condition=None):
def inverse(self, y, condition=None):
# Following notation from the paper
x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives
in_bounds = jnp.logical_and(y >= -self.interval, y <= self.interval)
in_bounds = jnp.logical_and(y >= self.interval[0], y <= self.interval[1])
y_robust = jnp.where(in_bounds, y, 0) # To avoid nans
k = jnp.searchsorted(y_pos, y_robust) - 1
xk, xk1, yk, yk1 = x_pos[k], x_pos[k + 1], y_pos[k], y_pos[k + 1]
Expand All @@ -134,7 +137,7 @@ def inverse(self, y, condition=None):
x = xi * (xk1 - xk) + xk

# avoid numerical precision issues transforming from in -> out of bounds
x = jnp.clip(x, -self.interval, self.interval)
x = jnp.clip(x, self.interval[0], self.interval[1])
return jnp.where(in_bounds, x, y)

def inverse_and_log_det(self, y, condition=None):
Expand All @@ -146,7 +149,7 @@ def derivative(self, x) -> Array:
"""The derivative dy/dx of the forward transformation."""
# Following notation from the paper (eq. 5)
x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives
in_bounds = jnp.logical_and(x >= -self.interval, x <= self.interval)
in_bounds = jnp.logical_and(x >= self.interval[0], x <= self.interval[1])
x_robust = jnp.where(in_bounds, x, 0) # To avoid nans
k = jnp.searchsorted(x_pos, x_robust) - 1
xi = (x_robust - x_pos[k]) / (x_pos[k + 1] - x_pos[k])
Expand Down
4 changes: 2 additions & 2 deletions flowjax/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@
Vmap,
)
from flowjax.distributions import AbstractDistribution, Transformed
from flowjax.wrappers import BijectionReparam, NonTrainable, WeightNormalization
from flowjax.wrappers import BijectionReparam, WeightNormalization, non_trainable


def _affine_with_min_scale(min_scale: float = 1e-2) -> Affine:
scale_reparam = Chain([SoftPlus(), NonTrainable(Loc(min_scale))])
scale_reparam = Chain([SoftPlus(), non_trainable(Loc(min_scale))])
return eqx.tree_at(
where=lambda aff: aff.scale,
pytree=Affine(),
Expand Down
2 changes: 1 addition & 1 deletion flowjax/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _apply_inverse_and_check_valid(bijection, arr):
jnp.logical_and(jnp.isfinite(arr), ~jnp.isfinite(param_inv)),
"Non-finite value(s) introduced when reparameterizing. This suggests "
"the parameter vector passed to BijectionReparam was incompatible with "
f"the bijection used for reparmeterizing ({type(bijection).__name__}).",
f"the bijection used for reparameterizing ({type(bijection).__name__}).",
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ license = { file = "LICENSE" }
name = "flowjax"
readme = "README.md"
requires-python = ">=3.10"
version = "12.4.0"
version = "12.5.0"

[project.urls]
repository = "https://github.com/danielward27/flowjax"
Expand Down
30 changes: 22 additions & 8 deletions tests/test_bijections/test_rational_quadratic_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,41 @@
from jax.tree_util import tree_map

from flowjax.bijections import RationalQuadraticSpline
from flowjax.bijections.rational_quadratic_spline import _real_to_increasing_on_interval


def test_RationalQuadraticSpline_tails():
@pytest.mark.parametrize("interval", [3, (-4, 5)])
def test_RationalQuadraticSpline_tails(interval):
key = jr.PRNGKey(0)
x = jnp.array([-20, 0.1, 2, 20])
spline = RationalQuadraticSpline(knots=10, interval=3)
spline = RationalQuadraticSpline(knots=10, interval=interval)

# Change to random initialisation, rather than identity.
spline = tree_map(
lambda x: jr.normal(key, x.shape) if eqx.is_inexact_array(x) else x,
spline,
)

x = jr.uniform(key, (5,), minval=spline.interval[0], maxval=spline.interval[1])
y = vmap(spline.transform)(x)
expected_changed = jnp.array([True, False, False, True]) # identity padding
assert ((jnp.abs(y - x) <= 1e-5) == expected_changed).all()
assert pytest.approx(x, abs=1e-5) != y

# Outside interval, default to identity
x = jnp.array([spline.interval[0] - 1, spline.interval[1] + 1])
y = vmap(spline.transform)(x)
assert pytest.approx(x, abs=1e-5) == y


def test_RationalQuadraticSpline_init():
@pytest.mark.parametrize("interval", [3, (-4, 5)])
def test_RationalQuadraticSpline_init(interval):
# Test it is initialized at the identity
x = jnp.array([-1, 0.1, 2, 1])
spline = RationalQuadraticSpline(knots=10, interval=3)
x = jnp.array([-7, 0.1, 2, 1])
spline = RationalQuadraticSpline(knots=10, interval=interval)
y = vmap(spline.transform)(x)
assert pytest.approx(x, abs=1e-6) == y


def test_real_to_increasing_on_interval():
y = _real_to_increasing_on_interval(jnp.array([-3.0, -4, 5, 0, 0]), (-3, 7))
assert y.max() == 7
assert y.min() == -3
assert jnp.all(jnp.diff(y)) > 0

0 comments on commit bdae1ba

Please sign in to comment.