Merge pull request #168 from danielward27/docs
Update docs
danielward27 committed Jul 24, 2024
2 parents eeb3481 + 27fc50c commit 91b1c39
Showing 3 changed files with 9 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/api/losses.rst
@@ -5,3 +5,4 @@ Loss functions from ``flowjax.train.losses``.
 .. automodule:: flowjax.train.losses
    :members:
    :undoc-members:
+   :special-members: __call__
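The new ``:special-members: __call__`` option makes Sphinx autodoc render ``__call__`` methods, which matters when losses are implemented as callable classes whose documentation lives on ``__call__``. A hypothetical sketch of that pattern (``MyLoss`` is illustrative, not an actual flowjax loss):

# Hypothetical callable-class loss; without :special-members: __call__ in
# the automodule directive, autodoc would skip the __call__ docstring below.
class MyLoss:
    """A loss implemented as a callable class."""

    def __call__(self, params, static, x):
        """Compute the loss for the given parameters and data."""
        raise NotImplementedError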
13 changes: 7 additions & 6 deletions flowjax/wrappers.py
@@ -63,9 +63,9 @@ class AbstractUnwrappable(eqx.Module, Generic[T]):
     stop_gradient before accessing the parameters.

     If ``_dummy`` is set to an array (must have shape ()), this is used for inferring
-    vmapped dimensions (and sizes) when calling ``unwrap`` to automatically vecotorize
-    the method. In some cases this is important for supporting the case where an
-    ``AbstractUnwrappable`` is created within e.g. ``eqx.filter_vmap``.
+    vmapped dimensions (and sizes) when calling :func:`unwrap` to automatically
+    vectorize the method. In some cases this is important for supporting the case
+    where an :class:`AbstractUnwrappable` is created within e.g. ``eqx.filter_vmap``.
     """

     _dummy: eqx.AbstractVar[Int[Scalar, ""] | None]
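As a rough sketch of the situation the ``_dummy`` mechanism supports, assuming ``flowjax.wrappers`` provides ``Parameterize`` and ``unwrap`` as described in this release (verify against the installed version):

# Sketch of creating an AbstractUnwrappable under eqx.filter_vmap.
import equinox as eqx
import jax.numpy as jnp
from flowjax.wrappers import Parameterize, unwrap

@eqx.filter_vmap
def make(scale):
    # Each Parameterize is created under vmap, so its stored arrays gain a
    # leading batch axis that naive unwrapping would not account for.
    return Parameterize(jnp.exp, jnp.log(scale))

batched = make(jnp.ones(3))
# unwrap infers the vmapped dimensions from _dummy and vectorizes the
# unwrapping, recovering an array of shape (3,).
assert unwrap(batched).shape == (3,)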
@@ -97,8 +97,9 @@ def unwrap(self) -> T:
 class NonTrainable(AbstractUnwrappable[T]):
     """Applies stop gradient to all arraylike leaves before unwrapping.

-    See also ``non_trainable``, which is probably a generally prefereable way to achieve
-    similar behaviour, which wraps the arraylike leaves directly, rather than the tree.
+    See also :func:`non_trainable`, which is probably a generally preferable way to
+    achieve similar behaviour, as it wraps the arraylike leaves directly, rather
+    than the tree.

     Useful to mark pytrees (arrays, submodules, etc) as frozen/non-trainable. We also
     filter out NonTrainable nodes when partitioning parameters for training, or when
@@ -114,7 +115,7 @@ def unwrap(self) -> T:


 def non_trainable(tree: PyTree):
-    """Freezes parameters by wrapping inexact array leaves with ``NonTrainable``.
+    """Freezes parameters by wrapping inexact array leaves with :class:`NonTrainable`.

     Wrapping the arrays rather than the entire tree is often preferable, as it
     allows easier access to the attributes of the tree.
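A minimal usage sketch of the function, assuming the ``flowjax.wrappers`` API shown above and using ``equinox.nn.Linear`` purely as an example module:

# Sketch: freezing a module's parameters by wrapping its array leaves.
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from flowjax.wrappers import non_trainable, unwrap

linear = eqx.nn.Linear(2, 2, key=jr.PRNGKey(0))

# Wraps each inexact array leaf in NonTrainable; because the leaves (not
# the whole tree) are wrapped, attributes like frozen.weight stay reachable.
frozen = non_trainable(linear)

# Unwrapping applies stop_gradient to each wrapped leaf, so the values are
# unchanged but gradients through the parameters are zero during training.
restored = unwrap(frozen)
assert jnp.array_equal(restored.weight, linear.weight)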
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ license = { file = "LICENSE" }
 name = "flowjax"
 readme = "README.md"
 requires-python = ">=3.10"
-version = "13.0.0"
+version = "13.0.1"

 [project.urls]
 repository = "https://github.com/danielward27/flowjax"