Merge pull request #118 from danielward27/docs
Docs
danielward27 committed Nov 16, 2023
2 parents 282e45c + 9dcb0d4 commit 7d1fa6e
Showing 34 changed files with 362 additions and 350 deletions.
1 change: 1 addition & 0 deletions docs/api/bijections.rst
@@ -5,3 +5,4 @@ Bijections
:members:
:undoc-members:
:show-inheritance:
:member-order: groupwise
1 change: 0 additions & 1 deletion docs/api/losses.rst
@@ -5,4 +5,3 @@ Loss functions from ``flowjax.train.losses``.
.. automodule:: flowjax.train.losses
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/api/training.rst
@@ -9,4 +9,4 @@ corresponding conditioning variables if appropriate), we can use ``fit_to_data``
Alternatively, we can use ``fit_to_variational_target`` to fit the flow to a function
using variational inference.

.. autofunction:: flowjax.train.fit_to_variational_target
.. autofunction:: flowjax.train.fit_to_variational_target
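For orientation, a minimal sketch of the two entry points described above, assuming the ``flowjax.flows`` and ``flowjax.distributions`` APIs of this release (exact signatures and defaults may differ):

>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> from flowjax.distributions import Normal
>>> from flowjax.flows import masked_autoregressive_flow
>>> from flowjax.train import fit_to_data
>>> key = jr.PRNGKey(0)
>>> flow = masked_autoregressive_flow(key, base_dist=Normal(jnp.zeros(2)))
>>> x = jr.normal(key, (100, 2))  # observations to fit the flow to
>>> flow, losses = fit_to_data(key, flow, x, max_epochs=5)  # trained flow, loss history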
24 changes: 19 additions & 5 deletions docs/conf.py
@@ -1,7 +1,11 @@
"""Configuration file for the Sphinx documentation builder."""

import builtins
import sys
from pathlib import Path

builtins.GENERATING_DOCUMENTATION = True # For processing ArrayLike

import jax # noqa Required to avoid circular import

sys.path.insert(0, Path("..").resolve())
@@ -22,17 +26,23 @@
extensions = [
"sphinx.ext.viewcode",
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"nbsphinx",
"sphinx_copybutton",
"sphinx.ext.napoleon",
"sphinx_autodoc_typehints",
]

intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
}

templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

add_module_names = False
napoleon_include_init_with_doc = False
# napoleon_include_init_with_doc = False

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
Expand All @@ -50,8 +60,12 @@
}

pygments_style = "xcode"
autodoc_typehints = "none"
autodoc_member_order = "bysource"

copybutton_prompt_text = r">>> |\.\.\. "
copybutton_prompt_is_regexp = True

napoleon_use_rtype = False
napoleon_attr_annotations = True

autodoc_type_aliases = {"ArrayLike": "ArrayLike"}
add_module_names = False
2 changes: 1 addition & 1 deletion docs/faq.rst
@@ -4,7 +4,7 @@ FAQ
Freezing parameters
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Often it is useful to not train particular parameters. To achieve this we can provide a
``filter_spec`` to :py:func:`~flowjax.train.data_fit.fit_to_data`. For example, to avoid
``filter_spec`` to :func:`~flowjax.train.fit_to_data`. For example, to avoid
training the base distribution, we could create a ``filter_spec`` as follows

.. testsetup::
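The example body is truncated in this view; a minimal sketch of the usual pattern, assuming the flow exposes a ``base_dist`` attribute (the attribute name and the use of ``equinox.tree_at`` are assumptions):

>>> import equinox as eqx
>>> import jax.tree_util as jtu
>>> # Start by marking every leaf as trainable ...
>>> filter_spec = jtu.tree_map(lambda _: True, flow)
>>> # ... then mark the base distribution as frozen
>>> filter_spec = eqx.tree_at(
...     lambda tree: tree.base_dist, filter_spec, replace=False
... )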
1 change: 0 additions & 1 deletion flowjax/__init__.py
@@ -1,5 +1,4 @@
"""flowjax - Basic flowjax implementation in jax."""

from importlib.metadata import version

__version__ = version("flowjax")
13 changes: 13 additions & 0 deletions flowjax/_custom_types.py
@@ -0,0 +1,13 @@
# We do this for now due to an incompatibility between equinox abstract class
# extensions and the documentation generator sphinx
# https://github.com/patrick-kidger/equinox/issues/591. This will likely be fixable with
# https://peps.python.org/pep-0649/ in python 3.13
import builtins

if getattr(builtins, "GENERATING_DOCUMENTATION", False):

class ArrayLike:
pass

else:
from jaxtyping import ArrayLike # noqa: F401
43 changes: 21 additions & 22 deletions flowjax/bijections/affine.py
@@ -1,4 +1,5 @@
"""Affine bijections."""
from __future__ import annotations

from collections.abc import Callable
from typing import ClassVar
@@ -20,12 +21,11 @@ class Affine(AbstractBijection):
``loc`` and ``scale`` should broadcast to the desired shape of the bijection.
Args:
loc (ArrayLike): Location parameter. Defaults to 0.
scale (ArrayLike): Scale parameter. Defaults to 1.
positivity_constraint (AbstractBijection | None): Bijection with shape
matching the Affine bijection, that maps the scale parameter from an
unbounded domain to the positive domain. Defaults to
:class:`~flowjax.bijections.SoftPlus`.
loc: Location parameter. Defaults to 0.
scale: Scale parameter. Defaults to 1.
positivity_constraint: Bijection with shape matching the Affine bijection, that
maps the scale parameter from an unbounded domain to the positive domain.
Defaults to :class:`~flowjax.bijections.SoftPlus`.
"""

shape: tuple[int, ...]
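To make the reworked docstring concrete, a brief usage sketch (a hedged illustration, not part of the diff):

>>> import jax.numpy as jnp
>>> from flowjax.bijections import Affine
>>> bijection = Affine(loc=jnp.array([0.0, 1.0]), scale=jnp.array([1.0, 2.0]))
>>> bijection.transform(jnp.ones(2))  # computes loc + scale * x -> [1., 3.]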
@@ -77,17 +77,16 @@ class TriangularAffine(AbstractBijection):
triangular matrix, and :math:`b` is the bias vector.
Args:
loc (ArrayLike): Location parameter. If this is scalar, it is broadcast to the
dimension inferred from arr.
arr (ArrayLike): Triangular matrix.
lower (bool): Whether the mask should select the lower or upper
triangular matrix (other elements ignored). Defaults to True (lower).
weight_normalisation (bool): If true, carry out weight normalisation.
positivity_constraint (AbstractBijection): Bijection with shape matching the
dimension of the triangular affine bijection, that maps the diagonal
entries of the array from an unbounded domain to the positive domain.
Also used for weight normalisation parameters, if used. Defaults to
SoftPlus.
loc: Location parameter. If this is scalar, it is broadcast to the dimension
inferred from arr.
arr: Triangular matrix.
lower: Whether the mask should select the lower or upper
triangular matrix (other elements ignored). Defaults to True (lower).
weight_normalisation: If true, carry out weight normalisation.
positivity_constraint: Bijection with shape matching the dimension of the
triangular affine bijection, that maps the diagonal entries of the array
from an unbounded domain to the positive domain. Also used for weight
normalisation parameters, if used. Defaults to SoftPlus.
"""
shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
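A hedged construction sketch for the arguments documented above (values are illustrative):

>>> import jax.numpy as jnp
>>> from flowjax.bijections import TriangularAffine
>>> arr = jnp.array([[1.0, 0.0], [0.5, 2.0]])  # lower-triangular weight matrix
>>> bijection = TriangularAffine(loc=0, arr=arr)  # y = arr @ x + loc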
@@ -181,11 +180,11 @@ class AdditiveCondition(AbstractBijection):
module with trainable parameters.
Args:
module (Callable[[ArrayLike], ArrayLike]): A callable (e.g. a function or
callable module) that maps array with shape cond_shape, to a shape
that is broadcastable with the shape of the bijection.
shape (tuple[int, ...]): The shape of the bijection.
cond_shape (tuple[int, ...]): The condition shape of the bijection.
module: A callable (e.g. a function or callable module) that maps an array with
shape ``cond_shape`` to a shape that is broadcastable with the shape of the
bijection.
shape: The shape of the bijection.
cond_shape: The condition shape of the bijection.
Example:
Conditioning using a linear transformation
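The linear-transformation example is truncated in this view; a plausible reconstruction of the pattern, assuming ``equinox.nn.Linear`` as the conditioning module:

>>> import equinox as eqx
>>> import jax.random as jr
>>> from flowjax.bijections import AdditiveCondition
>>> linear = eqx.nn.Linear(2, 3, key=jr.PRNGKey(0))  # maps condition (2,) -> (3,)
>>> bijection = AdditiveCondition(linear, shape=(3,), cond_shape=(2,))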
29 changes: 13 additions & 16 deletions flowjax/bijections/bijection.py
@@ -6,14 +6,13 @@
bijection can be used to invert the orientation if a fast inverse is desired (e.g.
maximum likelihood fitting of flows).
"""

import functools
from abc import abstractmethod

import equinox as eqx
from jax import Array
from jax.typing import ArrayLike

from flowjax._custom_types import ArrayLike
from flowjax.utils import arraylike_to_array


@@ -98,10 +97,10 @@ def transform(self, x: ArrayLike, condition: ArrayLike | None = None) -> Array:
"""Apply the forward transformation.
Args:
x (ArrayLike): Input with shape matching bijections.shape.
condition (ArrayLike | None, optional): Condition, with shape matching
bijection.cond_shape, required for conditional bijections. Defaults to
None.
x: Input with shape matching ``bijection.shape``.
condition: Condition, with shape matching ``bijection.cond_shape``, required
for conditional bijections and ignored for unconditional bijections.
Defaults to None.
"""

@abstractmethod
@@ -113,19 +112,18 @@ def transform_and_log_det(
"""Apply transformation and compute the log absolute Jacobian determinant.
Args:
x (ArrayLike): Input with shape matching the bijections shape
condition (ArrayLike | None, optional): . Defaults to None.
x: Input with shape matching the bijection's shape.
condition: Condition, with shape matching ``bijection.cond_shape``. Defaults to None.
"""

@abstractmethod
def inverse(self, y: ArrayLike, condition: ArrayLike | None = None) -> Array:
"""Compute the inverse transformation.
Args:
y (ArrayLike): Input array with shape matching bijection.shape
condition (ArrayLike | None, optional): Condition array with shape matching
bijection.cond_shape. Required for conditional bijections. Defaults to
None.
y: Input array with shape matching bijection.shape
condition: Condition array with shape matching bijection.cond_shape.
Required for conditional bijections. Defaults to None.
"""

@abstractmethod
@@ -137,8 +135,7 @@ def inverse_and_log_det(
"""Inverse transformation and corresponding log absolute jacobian determinant.
Args:
y (ArrayLike): Input array with shape matching bijection.shape.
condition (ArrayLike | None, optional): Condition array with shape matching
bijection.cond_shape. Required for conditional bijections. Defaults to
None.
y: Input array with shape matching bijection.shape.
condition: Condition array with shape matching bijection.cond_shape.
Required for conditional bijections. Defaults to None.
"""
19 changes: 9 additions & 10 deletions flowjax/bijections/block_autoregressive_network.py
@@ -49,16 +49,15 @@ class BlockAutoregressiveNetwork(AbstractBijection):
densities (see https://github.com/danielward27/flowjax/issues/102).
Args:
key (KeyArray): Jax PRNGKey
dim (int): Dimension of the distribution.
cond_dim (tuple[int, ...] | None): Dimension of conditioning variables.
depth (int): Number of hidden layers in the network.
block_dim (int): Block dimension (hidden layer size is `dim*block_dim`).
activation: (Bijection | Callable | None). Activation function, either
a scalar bijection or a callable that computes the activation for a
scalar value. Note that the activation should be bijective
to ensure invertibility of the network and in general should map
real -> real to ensure that when transforming a distribution (either
key: Jax PRNGKey
dim: Dimension of the distribution.
cond_dim: Dimension of conditioning variables.
depth: Number of hidden layers in the network.
block_dim: Block dimension (hidden layer size is `dim*block_dim`).
activation: Activation function, either a scalar bijection or a callable that
computes the activation for a scalar value. Note that the activation should
be bijective to ensure invertibility of the network and in general should
map real -> real to ensure that when transforming a distribution (either
with the forward or inverse), the map is defined across the support of
the base distribution. Defaults to ``LeakyTanh(3)``.
"""
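A hedged construction sketch matching the documented arguments (values are illustrative; in practice this layer is usually built via ``flowjax.flows.block_neural_autoregressive_flow``):

>>> import jax.random as jr
>>> from flowjax.bijections import BlockAutoregressiveNetwork
>>> ban = BlockAutoregressiveNetwork(
...     jr.PRNGKey(0), dim=2, cond_dim=None, depth=1, block_dim=4
... )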
4 changes: 2 additions & 2 deletions flowjax/bijections/chain.py
@@ -9,8 +9,8 @@ class Chain(AbstractBijection):
"""Chain together arbitrary bijections to form another bijection.
Args:
bijections (Sequence[Bijection]): Sequence of bijections. The bijection
shapes must match, and any none None condition shapes must match.
bijections: Sequence of bijections. The bijection shapes must match, and any
non-None condition shapes must match.
"""

shape: tuple[int, ...]
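A minimal sketch of composing bijections with ``Chain`` (a hedged illustration):

>>> from flowjax.bijections import Affine, Chain, Exp
>>> bijection = Chain([Affine(loc=1.0), Exp()])  # applies Affine first, then Exp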
9 changes: 4 additions & 5 deletions flowjax/bijections/concatenate.py
@@ -15,9 +15,8 @@
See also :class:`Stack`.
Args:
bijections (Sequence[Bijection]): Bijections, to stack into a single
bijection.
axis (int): Axis along which to stack. Defaults to 0.
bijections: Bijections, to stack into a single bijection.
axis: Axis along which to stack. Defaults to 0.
"""

shape: tuple[int, ...]
@@ -94,8 +93,8 @@ class Stack(AbstractBijection):
See also :class:`Concatenate`.
Args:
bijections (list[Bijection]): Bijections.
axis (int): Axis along which to stack. Defaults to 0.
bijections: Bijections.
axis: Axis along which to stack. Defaults to 0.
"""

shape: tuple[int, ...]
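A hedged sketch of the distinction: ``Stack`` lays bijections along a new axis, whereas ``Concatenate`` joins them along an existing axis of their shapes:

>>> from flowjax.bijections import Affine, Exp, Stack
>>> bijection = Stack([Affine(), Exp(), Affine()])  # shape (3,): one bijection per element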
20 changes: 9 additions & 11 deletions flowjax/bijections/coupling.py
@@ -18,17 +18,15 @@
"""Coupling layer implementation (https://arxiv.org/abs/1605.08803).
Args:
key (KeyArray): Jax PRNGKey
transformer (AbstractBijection): Unconditional bijection with shape ()
to be parameterised by the conditioner neural netork.
untransformed_dim (int): Number of untransformed conditioning variables
(e.g. dim // 2).
dim (int): Total dimension.
cond_dim (int | None): Dimension of additional conditioning variables.
nn_width (int): Neural network hidden layer width.
nn_depth (int): Neural network hidden layer size.
nn_activation (Callable): Neural network activation function.
Defaults to jnn.relu.
key: Jax PRNGKey
transformer: Unconditional bijection with shape () to be parameterised by the
conditioner neural network.
untransformed_dim: Number of untransformed conditioning variables (e.g. dim//2).
dim: Total dimension.
cond_dim: Dimension of additional conditioning variables.
nn_width: Neural network hidden layer width.
nn_depth: Number of neural network hidden layers.
nn_activation: Neural network activation function. Defaults to jnn.relu.
"""

shape: tuple[int, ...]
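A hedged construction sketch using the documented arguments (keyword names taken from the docstring, values illustrative):

>>> import jax.random as jr
>>> from flowjax.bijections import Affine, Coupling
>>> layer = Coupling(
...     jr.PRNGKey(0),
...     transformer=Affine(),  # scalar bijection parameterised by the conditioner
...     untransformed_dim=2,
...     dim=4,
...     cond_dim=None,
...     nn_width=32,
...     nn_depth=1,
... )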
3 changes: 1 addition & 2 deletions flowjax/bijections/exp.py
@@ -10,8 +10,7 @@ class Exp(AbstractBijection):
"""Elementwise exponential transform (forward) and log transform (inverse).
Args:
shape (tuple[int, ...] | None): Shape of the bijection.
Defaults to None.
shape: Shape of the bijection. Defaults to ().
"""

shape: tuple[int, ...] = ()
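A one-line usage sketch (hedged):

>>> import jax.numpy as jnp
>>> from flowjax.bijections import Exp
>>> Exp().transform(jnp.array(0.0))  # exp(0) == 1; the inverse applies log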
26 changes: 13 additions & 13 deletions flowjax/bijections/jax_transforms.py
@@ -17,9 +17,9 @@ class Scan(AbstractBijection):
to construct these using ``equinox.filter_vmap``.
Args:
bijection (AbstractBijection): A bijection, in which the arrays leaves have
an additional leading axis to scan over. It is often can convenient to
create compatible bijections with ``equinox.filter_vmap``.
bijection: A bijection, in which the array leaves have an additional leading
axis to scan over. It is often convenient to create compatible
bijections with ``equinox.filter_vmap``.
Example:
Below is equivalent to ``Chain([Affine(p) for p in params])``.
@@ -92,16 +92,16 @@ class Vmap(AbstractBijection):
"""Applies vmap to bijection methods to add a batch dimension to the bijection.
Args:
bijection (AbstractBijection): The bijection to vectorize.
in_axis (int | None | Callable): Specify which axes of the bijection
parameters to vectorise over. It should be a PyTree of ``None``, ``int``
with the tree structure being a prefix of the bijection, or a callable
mapping ``Leaf -> Union[None, int]``. Defaults to None.
axis_size (int, optional): The size of the new axis. This should be left
unspecified if in_axis is provided, as the size can be inferred from the
bijection parameters. Defaults to None.
in_axis_condition (int | None, optional): Optionally define an axis of
the conditioning variable to vectorize over. Defaults to None.
bijection: The bijection to vectorize.
in_axis: Specify which axes of the bijection parameters to vectorise over. It
should be a PyTree of ``None``, ``int`` with the tree structure being a
prefix of the bijection, or a callable mapping ``Leaf -> Union[None, int]``.
Defaults to None.
axis_size: The size of the new axis. This should be left unspecified if in_axis
is provided, as the size can be inferred from the bijection parameters.
Defaults to None.
in_axis_condition: Optionally define an axis of the conditioning variable to
vectorize over. Defaults to None.
Example:
The two most common use cases are shown below:
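Both example blocks are truncated in this view; a hedged sketch of each pattern, assuming ``equinox.filter_vmap`` yields bijections whose array leaves gain a leading batch axis:

>>> import equinox as eqx
>>> import jax.numpy as jnp
>>> from flowjax.bijections import Affine, Scan, Vmap
>>> params = jnp.ones((3, 2))
>>> batched = eqx.filter_vmap(Affine)(params)  # leaves gain a leading axis of size 3
>>> scanned = Scan(batched)                    # ~ Chain([Affine(p) for p in params])
>>> vmapped = Vmap(Affine(), axis_size=3)      # broadcast one Affine to shape (3,)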
16 changes: 8 additions & 8 deletions flowjax/bijections/masked_autoregressive.py
@@ -27,14 +27,14 @@
- https://arxiv.org/abs/1705.07057v4
Args:
key (KeyArray): Jax PRNGKey
transformer (AbstractBijection): Bijection with shape () to be parameterised
by the autoregressive network.
dim (int): Dimension.
cond_dim (int | None): Dimension of any conditioning variables.
nn_width (int): Neural network width.
nn_depth (int): Neural network depth.
nn_activation (Callable): Neural network activation. Defaults to jnn.relu.
key: Jax PRNGKey
transformer: Bijection with shape () to be parameterised by the autoregressive
network.
dim: Dimension.
cond_dim: Dimension of any conditioning variables.
nn_width: Neural network width.
nn_depth: Neural network depth.
nn_activation: Neural network activation. Defaults to jnn.relu.
"""

shape: tuple[int, ...]
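A hedged construction sketch matching the documented arguments (keyword names taken from the docstring, values illustrative):

>>> import jax.random as jr
>>> from flowjax.bijections import Affine, MaskedAutoregressive
>>> layer = MaskedAutoregressive(
...     jr.PRNGKey(0),
...     transformer=Affine(),
...     dim=3,
...     cond_dim=None,
...     nn_width=32,
...     nn_depth=1,
... )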