From 02a8a2bca7363bc35fc157609d9dab6a83c85955 Mon Sep 17 00:00:00 2001 From: danielward27 Date: Mon, 29 Apr 2024 11:12:52 +0100 Subject: [PATCH 1/3] Add logistic distribution --- flowjax/bijections/affine.py | 3 +-- flowjax/distributions.py | 44 ++++++++++++++++++++++++++++++++---- tests/test_distributions.py | 4 ++++ 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/flowjax/bijections/affine.py b/flowjax/bijections/affine.py index 2c8c5282..b676c025 100644 --- a/flowjax/bijections/affine.py +++ b/flowjax/bijections/affine.py @@ -140,7 +140,7 @@ class TriangularAffine(AbstractBijection): def __init__( self, - loc: Shaped[Array, " dim"], + loc: Shaped[ArrayLike, " #dim"], arr: Shaped[Array, "dim dim"], *, lower: bool = True, @@ -150,7 +150,6 @@ def __init__( arr.shape[0] != arr.shape[1] ): # TODO unnecersary if beartype enabled raise ValueError("arr must be a square, 2-dimensional matrix.") - dim = arr.shape[0] def _to_triangular(diag, arr): diff --git a/flowjax/distributions.py b/flowjax/distributions.py index 92e214c3..c8b1b3b1 100644 --- a/flowjax/distributions.py +++ b/flowjax/distributions.py @@ -13,7 +13,7 @@ from equinox import AbstractVar from jax.numpy import linalg from jax.scipy import stats as jstats -from jaxtyping import Array, ArrayLike, PRNGKeyArray +from jaxtyping import Array, ArrayLike, PRNGKeyArray, Shaped from flowjax.bijections import ( AbstractBijection, @@ -457,7 +457,11 @@ class MultivariateNormal(AbstractTransformed): base_dist: StandardNormal bijection: TriangularAffine - def __init__(self, loc: ArrayLike, covariance: ArrayLike): + def __init__( + self, + loc: Shaped[ArrayLike, "#dim"], + covariance: Shaped[Array, "dim dim"], + ): self.bijection = TriangularAffine(loc, linalg.cholesky(covariance)) self.base_dist = StandardNormal(self.bijection.shape) @@ -685,10 +689,42 @@ class Exponential(AbstractTransformed): base_dist: _StandardExponential bijection: Scale - def __init__(self, rate: Array): - self.base_dist = _StandardExponential(rate.shape) + def __init__(self, rate: ArrayLike): + self.base_dist = _StandardExponential(jnp.shape(rate)) self.bijection = Scale(1 / rate) @property def rate(self): return 1 / unwrap(self.bijection.scale) + + +class _StandardLogistic(AbstractDistribution): + shape: tuple[int, ...] = () + cond_shape: ClassVar[None] = None + + def _sample(self, key, condition=None): + return jr.logistic(key, self.shape) + + def _log_prob(self, x, condition=None): + return jstats.logistic.logpdf(x).sum() + + +class Logistic(AbstractLocScaleDistribution): + """Logistic distribution. + + ``loc`` and ``scale`` should broadcast to the shape of the distribution. + + Args: + loc: Means. Defaults to 0. + scale: Standard deviations. Defaults to 1. + + """ + + base_dist: _StandardLogistic + bijection: Affine + + def __init__(self, loc: ArrayLike = 0, scale: ArrayLike = 1): + self.base_dist = _StandardLogistic( + shape=jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)), + ) + self.bijection = Affine(loc=loc, scale=scale) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 677b80ea..7cf335aa 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -15,6 +15,7 @@ Exponential, Gumbel, Laplace, + Logistic, LogNormal, MultivariateNormal, Normal, @@ -25,6 +26,7 @@ _StandardCauchy, _StandardGumbel, _StandardLaplace, + _StandardLogistic, _StandardStudentT, _StandardUniform, ) @@ -50,6 +52,8 @@ "_StandardLaplace": _StandardLaplace, "Laplace": lambda shape: Laplace(jnp.ones(shape)), "Exponential": lambda shape: Exponential(jnp.ones(shape)), + "_StandardLogistic": _StandardLogistic, + "Logistic": lambda shape: Logistic(jnp.ones(shape)), } From ff6781d9f845d5978270b88b909f798f70b0162d Mon Sep 17 00:00:00 2001 From: danielward27 Date: Mon, 29 Apr 2024 11:14:30 +0100 Subject: [PATCH 2/3] rm space --- flowjax/distributions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flowjax/distributions.py b/flowjax/distributions.py index c8b1b3b1..620c815a 100644 --- a/flowjax/distributions.py +++ b/flowjax/distributions.py @@ -717,7 +717,6 @@ class Logistic(AbstractLocScaleDistribution): Args: loc: Means. Defaults to 0. scale: Standard deviations. Defaults to 1. - """ base_dist: _StandardLogistic From a750f8da6829bc2df3939c1fefd040d9cc13d4b5 Mon Sep 17 00:00:00 2001 From: danielward27 Date: Mon, 29 Apr 2024 11:25:16 +0100 Subject: [PATCH 3/3] Tidy docs and default rate for exponential --- flowjax/distributions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/flowjax/distributions.py b/flowjax/distributions.py index 620c815a..e525a994 100644 --- a/flowjax/distributions.py +++ b/flowjax/distributions.py @@ -409,7 +409,7 @@ class Normal(AbstractLocScaleDistribution): ``loc`` and ``scale`` should broadcast to the desired shape of the distribution. Args: - loc: Means. Defaults to 0. + loc: Means. Defaults to 0. Defaults to 0. scale: Standard deviations. Defaults to 1. """ @@ -541,7 +541,7 @@ class Gumbel(AbstractLocScaleDistribution): ``loc`` and ``scale`` should broadcast to the dimension of the distribution. Args: - loc: Location paramter. + loc: Location paramter. Defaults to 0. scale: Scale parameter. Defaults to 1. """ @@ -577,7 +577,7 @@ class Cauchy(AbstractLocScaleDistribution): ``loc`` and ``scale`` should broadcast to the dimension of the distribution. Args: - loc: Location paramter. + loc: Location paramter. Defaults to 0. scale: Scale parameter. Defaults to 1. """ @@ -689,7 +689,7 @@ class Exponential(AbstractTransformed): base_dist: _StandardExponential bijection: Scale - def __init__(self, rate: ArrayLike): + def __init__(self, rate: ArrayLike = 1): self.base_dist = _StandardExponential(jnp.shape(rate)) self.bijection = Scale(1 / rate) @@ -715,8 +715,8 @@ class Logistic(AbstractLocScaleDistribution): ``loc`` and ``scale`` should broadcast to the shape of the distribution. Args: - loc: Means. Defaults to 0. - scale: Standard deviations. Defaults to 1. + loc: Location parameter. Defaults to 0. + scale: Scale parameter. Defaults to 1. """ base_dist: _StandardLogistic