From 192f4f45abec77cefdb36c9600d83232f44724d5 Mon Sep 17 00:00:00 2001 From: danielward27 Date: Thu, 25 Apr 2024 15:30:07 +0100 Subject: [PATCH 1/3] Ensure right dtype params --- flowjax/bijections/affine.py | 3 ++- flowjax/bijections/utils.py | 9 +++++---- flowjax/distributions.py | 4 +++- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/flowjax/bijections/affine.py b/flowjax/bijections/affine.py index c23b572b..2c8c5282 100644 --- a/flowjax/bijections/affine.py +++ b/flowjax/bijections/affine.py @@ -67,7 +67,7 @@ class Loc(AbstractBijection): cond_shape: ClassVar[None] = None def __init__(self, loc: ArrayLike): - self.loc = arraylike_to_array(loc) + self.loc = arraylike_to_array(loc, dtype=float) self.shape = self.loc.shape def transform(self, x, condition=None): @@ -98,6 +98,7 @@ def __init__( self, scale: ArrayLike, ): + scale = arraylike_to_array(scale, "scale", dtype=float) self.scale = wrappers.BijectionReparam(scale, SoftPlus()) self.shape = jnp.shape(wrappers.unwrap(scale)) diff --git a/flowjax/bijections/utils.py b/flowjax/bijections/utils.py index 62e395d0..cb7e5fb7 100644 --- a/flowjax/bijections/utils.py +++ b/flowjax/bijections/utils.py @@ -6,7 +6,8 @@ import equinox as eqx import jax.numpy as jnp -from jaxtyping import Array, ArrayLike +import numpy as np +from jaxtyping import Array, Int from flowjax.bijections.bijection import AbstractBijection from flowjax.utils import arraylike_to_array @@ -63,11 +64,11 @@ class Permute(AbstractBijection): permutation: tuple[Array, ...] inverse_permutation: tuple[Array, ...] - def __init__(self, permutation: ArrayLike): - permutation = arraylike_to_array(permutation) + def __init__(self, permutation: Int[Array | np.ndarray, "..."]): + permutation = arraylike_to_array(permutation, dtype=int) permutation = eqx.error_if( permutation, - permutation.ravel().sort() != jnp.arange(permutation.size), + permutation.ravel().sort() != jnp.arange(permutation.size, dtype=int), "Invalid permutation array provided.", ) self.shape = permutation.shape diff --git a/flowjax/distributions.py b/flowjax/distributions.py index a2a5af78..2f03eb6e 100644 --- a/flowjax/distributions.py +++ b/flowjax/distributions.py @@ -500,7 +500,9 @@ class Uniform(AbstractLocScaleDistribution): bijection: Affine def __init__(self, minval: ArrayLike, maxval: ArrayLike): - minval, maxval = arraylike_to_array(minval), arraylike_to_array(maxval) + minval, maxval = ( + arraylike_to_array(arr, dtype=float) for arr in (minval, maxval) + ) shape = jnp.broadcast_shapes(minval.shape, maxval.shape) minval, maxval = eqx.error_if( (minval, maxval), maxval <= minval, "minval must be less than the maxval." From 5514b4748f85768c3048e3615264ecdca62cf661 Mon Sep 17 00:00:00 2001 From: danielward27 Date: Thu, 25 Apr 2024 15:30:22 +0100 Subject: [PATCH 2/3] Add tree flatten for numypro bijection --- flowjax/experimental/numpyro.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/flowjax/experimental/numpyro.py b/flowjax/experimental/numpyro.py index 20ac2d5d..465cf400 100644 --- a/flowjax/experimental/numpyro.py +++ b/flowjax/experimental/numpyro.py @@ -209,7 +209,10 @@ def condition(self): return jax.lax.stop_gradient(self._condition) def tree_flatten(self): - raise NotImplementedError() + return (self.bijection, self._condition, self.domain, self.codomain), ( + ("bijection", "_condition", "domain", "codomain"), + {}, + ) def _argcheck_domains(self): for k, v in {"domain": self.domain, "codomain": self.codomain}.items(): From f3e26e55eac9b07587b7e9218d82b4be74e75ca8 Mon Sep 17 00:00:00 2001 From: danielward27 Date: Thu, 25 Apr 2024 15:49:12 +0100 Subject: [PATCH 3/3] Leave uniform type check to affine --- flowjax/distributions.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/flowjax/distributions.py b/flowjax/distributions.py index 2f03eb6e..92e214c3 100644 --- a/flowjax/distributions.py +++ b/flowjax/distributions.py @@ -500,10 +500,7 @@ class Uniform(AbstractLocScaleDistribution): bijection: Affine def __init__(self, minval: ArrayLike, maxval: ArrayLike): - minval, maxval = ( - arraylike_to_array(arr, dtype=float) for arr in (minval, maxval) - ) - shape = jnp.broadcast_shapes(minval.shape, maxval.shape) + shape = jnp.broadcast_shapes(jnp.shape(minval), jnp.shape(maxval)) minval, maxval = eqx.error_if( (minval, maxval), maxval <= minval, "minval must be less than the maxval." ) @@ -598,7 +595,7 @@ class _StandardStudentT(AbstractDistribution): df: Array | AbstractUnwrappable[Array] def __init__(self, df: ArrayLike): - df = arraylike_to_array(df) + df = arraylike_to_array(df, dtype=float) df = eqx.error_if(df, df <= 0, "Degrees of freedom values must be positive.") self.shape = jnp.shape(df) self.df = BijectionReparam(df, SoftPlus())