Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow uneven interval in spline #166

Merged
merged 1 commit into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions flowjax/bijections/rational_quadratic_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

def _real_to_increasing_on_interval(
arr: Float[Array, " dim"],
interval: float | int = 1,
interval: tuple[int | float, int | float],
softmax_adjust: float = 1e-2,
*,
pad_with_ends: bool = True,
Expand All @@ -35,10 +35,11 @@ def _real_to_increasing_on_interval(
widths = jax.nn.softmax(arr)
widths = (widths + softmax_adjust / widths.size) / (1 + softmax_adjust)
widths = widths.at[0].set(widths[0] / 2)
pos = 2 * interval * jnp.cumsum(widths) - interval
scale = interval[1] - interval[0]
pos = interval[0] + scale * jnp.cumsum(widths)

if pad_with_ends:
pos = jnp.pad(pos, pad_width=1, constant_values=(-interval, interval))
pos = jnp.pad(pos, pad_width=1, constant_values=interval)

return pos

Expand All @@ -48,7 +49,8 @@ class RationalQuadraticSpline(AbstractBijection):

Args:
knots: Number of knots.
interval: interval to transform, [-interval, interval].
interval: Interval to transform, if a scalar value, uses [-interval, interval],
if a tuple, uses [interval[0], interval[1]]
min_derivative: Minimum dervivative. Defaults to 1e-3.
softmax_adjust: Controls minimum bin width and height by rescaling softmax
output, e.g. 0=no adjustment, 1=average softmax output with evenly spaced
Expand All @@ -57,7 +59,7 @@ class RationalQuadraticSpline(AbstractBijection):
"""

knots: int
interval: float | int
interval: tuple[int | float, int | float]
softmax_adjust: float | int
min_derivative: float
x_pos: Array | wrappers.AbstractUnwrappable[Array]
Expand All @@ -70,11 +72,12 @@ def __init__(
self,
*,
knots: int,
interval: float | int,
interval: float | int | tuple[int | float, int | float],
min_derivative: float = 1e-3,
softmax_adjust: float | int = 1e-2,
):
self.knots = knots
interval = interval if isinstance(interval, tuple) else (-interval, interval)
self.interval = interval
self.softmax_adjust = softmax_adjust
self.min_derivative = min_derivative
Expand All @@ -96,7 +99,7 @@ def __init__(
def transform(self, x, condition=None):
# Following notation from the paper
x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives
in_bounds = jnp.logical_and(x >= -self.interval, x <= self.interval)
in_bounds = jnp.logical_and(x >= self.interval[0], x <= self.interval[1])
x_robust = jnp.where(in_bounds, x, 0) # To avoid nans
k = jnp.searchsorted(x_pos, x_robust) - 1 # k is bin number
xi = (x_robust - x_pos[k]) / (x_pos[k + 1] - x_pos[k])
Expand All @@ -107,7 +110,7 @@ def transform(self, x, condition=None):
y = yk + num / den # eq. 4

# avoid numerical precision issues transforming from in -> out of bounds
y = jnp.clip(y, -self.interval, self.interval)
y = jnp.clip(y, self.interval[0], self.interval[1])
return jnp.where(in_bounds, y, x)

def transform_and_log_det(self, x, condition=None):
Expand All @@ -118,7 +121,7 @@ def transform_and_log_det(self, x, condition=None):
def inverse(self, y, condition=None):
# Following notation from the paper
x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives
in_bounds = jnp.logical_and(y >= -self.interval, y <= self.interval)
in_bounds = jnp.logical_and(y >= self.interval[0], y <= self.interval[1])
y_robust = jnp.where(in_bounds, y, 0) # To avoid nans
k = jnp.searchsorted(y_pos, y_robust) - 1
xk, xk1, yk, yk1 = x_pos[k], x_pos[k + 1], y_pos[k], y_pos[k + 1]
Expand All @@ -134,7 +137,7 @@ def inverse(self, y, condition=None):
x = xi * (xk1 - xk) + xk

# avoid numerical precision issues transforming from in -> out of bounds
x = jnp.clip(x, -self.interval, self.interval)
x = jnp.clip(x, self.interval[0], self.interval[1])
return jnp.where(in_bounds, x, y)

def inverse_and_log_det(self, y, condition=None):
Expand All @@ -146,7 +149,7 @@ def derivative(self, x) -> Array:
"""The derivative dy/dx of the forward transformation."""
# Following notation from the paper (eq. 5)
x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives
in_bounds = jnp.logical_and(x >= -self.interval, x <= self.interval)
in_bounds = jnp.logical_and(x >= self.interval[0], x <= self.interval[1])
x_robust = jnp.where(in_bounds, x, 0) # To avoid nans
k = jnp.searchsorted(x_pos, x_robust) - 1
xi = (x_robust - x_pos[k]) / (x_pos[k + 1] - x_pos[k])
Expand Down
4 changes: 2 additions & 2 deletions flowjax/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@
Vmap,
)
from flowjax.distributions import AbstractDistribution, Transformed
from flowjax.wrappers import BijectionReparam, NonTrainable, WeightNormalization
from flowjax.wrappers import BijectionReparam, WeightNormalization, non_trainable


def _affine_with_min_scale(min_scale: float = 1e-2) -> Affine:
scale_reparam = Chain([SoftPlus(), NonTrainable(Loc(min_scale))])
scale_reparam = Chain([SoftPlus(), non_trainable(Loc(min_scale))])
return eqx.tree_at(
where=lambda aff: aff.scale,
pytree=Affine(),
Expand Down
2 changes: 1 addition & 1 deletion flowjax/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _apply_inverse_and_check_valid(bijection, arr):
jnp.logical_and(jnp.isfinite(arr), ~jnp.isfinite(param_inv)),
"Non-finite value(s) introduced when reparameterizing. This suggests "
"the parameter vector passed to BijectionReparam was incompatible with "
f"the bijection used for reparmeterizing ({type(bijection).__name__}).",
f"the bijection used for reparameterizing ({type(bijection).__name__}).",
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ license = { file = "LICENSE" }
name = "flowjax"
readme = "README.md"
requires-python = ">=3.10"
version = "12.4.0"
version = "12.5.0"

[project.urls]
repository = "https://github.com/danielward27/flowjax"
Expand Down
30 changes: 22 additions & 8 deletions tests/test_bijections/test_rational_quadratic_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,41 @@
from jax.tree_util import tree_map

from flowjax.bijections import RationalQuadraticSpline
from flowjax.bijections.rational_quadratic_spline import _real_to_increasing_on_interval


def test_RationalQuadraticSpline_tails():
@pytest.mark.parametrize("interval", [3, (-4, 5)])
def test_RationalQuadraticSpline_tails(interval):
key = jr.PRNGKey(0)
x = jnp.array([-20, 0.1, 2, 20])
spline = RationalQuadraticSpline(knots=10, interval=3)
spline = RationalQuadraticSpline(knots=10, interval=interval)

# Change to random initialisation, rather than identity.
spline = tree_map(
lambda x: jr.normal(key, x.shape) if eqx.is_inexact_array(x) else x,
spline,
)

x = jr.uniform(key, (5,), minval=spline.interval[0], maxval=spline.interval[1])
y = vmap(spline.transform)(x)
expected_changed = jnp.array([True, False, False, True]) # identity padding
assert ((jnp.abs(y - x) <= 1e-5) == expected_changed).all()
assert pytest.approx(x, abs=1e-5) != y

# Outside interval, default to identity
x = jnp.array([spline.interval[0] - 1, spline.interval[1] + 1])
y = vmap(spline.transform)(x)
assert pytest.approx(x, abs=1e-5) == y


def test_RationalQuadraticSpline_init():
@pytest.mark.parametrize("interval", [3, (-4, 5)])
def test_RationalQuadraticSpline_init(interval):
# Test it is initialized at the identity
x = jnp.array([-1, 0.1, 2, 1])
spline = RationalQuadraticSpline(knots=10, interval=3)
x = jnp.array([-7, 0.1, 2, 1])
spline = RationalQuadraticSpline(knots=10, interval=interval)
y = vmap(spline.transform)(x)
assert pytest.approx(x, abs=1e-6) == y


def test_real_to_increasing_on_interval():
y = _real_to_increasing_on_interval(jnp.array([-3.0, -4, 5, 0, 0]), (-3, 7))
assert y.max() == 7
assert y.min() == -3
assert jnp.all(jnp.diff(y)) > 0
Loading