Skip to content

Commit

Permalink
Merge pull request #158 from danielward27/types
Browse files Browse the repository at this point in the history
Types
  • Loading branch information
danielward27 committed Apr 25, 2024
2 parents 78e143c + f3e26e5 commit e789f68
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 9 deletions.
3 changes: 2 additions & 1 deletion flowjax/bijections/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class Loc(AbstractBijection):
cond_shape: ClassVar[None] = None

def __init__(self, loc: ArrayLike):
self.loc = arraylike_to_array(loc)
self.loc = arraylike_to_array(loc, dtype=float)
self.shape = self.loc.shape

def transform(self, x, condition=None):
Expand Down Expand Up @@ -98,6 +98,7 @@ def __init__(
self,
scale: ArrayLike,
):
scale = arraylike_to_array(scale, "scale", dtype=float)
self.scale = wrappers.BijectionReparam(scale, SoftPlus())
self.shape = jnp.shape(wrappers.unwrap(scale))

Expand Down
9 changes: 5 additions & 4 deletions flowjax/bijections/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike
import numpy as np
from jaxtyping import Array, Int

from flowjax.bijections.bijection import AbstractBijection
from flowjax.utils import arraylike_to_array
Expand Down Expand Up @@ -63,11 +64,11 @@ class Permute(AbstractBijection):
permutation: tuple[Array, ...]
inverse_permutation: tuple[Array, ...]

def __init__(self, permutation: ArrayLike):
permutation = arraylike_to_array(permutation)
def __init__(self, permutation: Int[Array | np.ndarray, "..."]):
permutation = arraylike_to_array(permutation, dtype=int)
permutation = eqx.error_if(
permutation,
permutation.ravel().sort() != jnp.arange(permutation.size),
permutation.ravel().sort() != jnp.arange(permutation.size, dtype=int),
"Invalid permutation array provided.",
)
self.shape = permutation.shape
Expand Down
5 changes: 2 additions & 3 deletions flowjax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,8 +500,7 @@ class Uniform(AbstractLocScaleDistribution):
bijection: Affine

def __init__(self, minval: ArrayLike, maxval: ArrayLike):
minval, maxval = arraylike_to_array(minval), arraylike_to_array(maxval)
shape = jnp.broadcast_shapes(minval.shape, maxval.shape)
shape = jnp.broadcast_shapes(jnp.shape(minval), jnp.shape(maxval))
minval, maxval = eqx.error_if(
(minval, maxval), maxval <= minval, "minval must be less than the maxval."
)
Expand Down Expand Up @@ -596,7 +595,7 @@ class _StandardStudentT(AbstractDistribution):
df: Array | AbstractUnwrappable[Array]

def __init__(self, df: ArrayLike):
df = arraylike_to_array(df)
df = arraylike_to_array(df, dtype=float)
df = eqx.error_if(df, df <= 0, "Degrees of freedom values must be positive.")
self.shape = jnp.shape(df)
self.df = BijectionReparam(df, SoftPlus())
Expand Down
5 changes: 4 additions & 1 deletion flowjax/experimental/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,10 @@ def condition(self):
return jax.lax.stop_gradient(self._condition)

def tree_flatten(self):
raise NotImplementedError()
return (self.bijection, self._condition, self.domain, self.codomain), (
("bijection", "_condition", "domain", "codomain"),
{},
)

def _argcheck_domains(self):
for k, v in {"domain": self.domain, "codomain": self.codomain}.items():
Expand Down

0 comments on commit e789f68

Please sign in to comment.