Add non_trainable #165

Merged: 1 commit, Jun 21, 2024
23 changes: 11 additions & 12 deletions docs/faq.rst
@@ -4,22 +4,21 @@ FAQ
 Freezing parameters
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 Often it is useful to not train particular parameters. The easiest way to achieve this
-is to use the :class:`flowjax.wrappers.NonTrainable` wrapper class. For example, to
-avoid training the base distribution of a transformed distribution:
-
-.. testsetup::
-
-    from flowjax.distributions import Normal
-    flow = Normal()
+is to use :func:`flowjax.wrappers.non_trainable`. This will wrap the inexact array
+leaves with :class:`flowjax.wrappers.NonTrainable`, which will apply ``stop_gradient``
+when unwrapping the parameters. For commonly used distribution and bijection methods,
+unwrapping is applied automatically. For example

 .. doctest::

+    >>> from flowjax.distributions import Normal
+    >>> from flowjax.wrappers import non_trainable
+    >>> dist = Normal()
+    >>> dist = non_trainable(dist)

-    >>> import equinox as eqx
-    >>> from flowjax.wrappers import NonTrainable
-    >>> flow = eqx.tree_at(lambda flow: flow.base_dist, flow, replace_fn=NonTrainable)
+To mark part of a tree as frozen, use ``non_trainable`` with e.g.
+``equinox.tree_at`` or ``jax.tree_map``.

-If you wish to avoid training e.g. a specific type, it may be easier to use
-``jax.tree_map`` to apply the NonTrainable wrapper as required.

 Standardizing variables
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
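For illustration, here is a minimal sketch of the pattern the updated FAQ describes: freezing only part of a model with ``non_trainable`` via ``equinox.tree_at``. This is an illustrative example, not part of this PR's changes; the ``bijection`` attribute and the zero-gradient check are borrowed from the PR's updated test rather than documented API, and here Normal's trainable arrays all live in the bijection (as that test relies on), so every gradient comes out zero.

    >>> import equinox as eqx
    >>> import jax.flatten_util
    >>> from flowjax.distributions import Normal
    >>> from flowjax.wrappers import non_trainable

    >>> dist = Normal()
    >>> # Freeze only the bijection: its inexact array leaves are wrapped in
    >>> # NonTrainable, so stop_gradient is applied whenever they are unwrapped.
    >>> dist = eqx.tree_at(lambda d: d.bijection, dist, replace_fn=non_trainable)

    >>> # Gradients with respect to the frozen parameters are zero.
    >>> grads = eqx.filter_grad(lambda d, x: d.log_prob(x))(dist, 1.0)
    >>> bool((jax.flatten_util.ravel_pytree(grads)[0] == 0).all())
    True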
25 changes: 24 additions & 1 deletion flowjax/wrappers.py
@@ -97,8 +97,11 @@ def unwrap(self) -> T:
 class NonTrainable(AbstractUnwrappable[T]):
     """Applies stop gradient to all arraylike leaves before unwrapping.

+    See also ``non_trainable``, which is probably a preferable way to achieve
+    similar behaviour; it wraps the arraylike leaves directly, rather than the tree.
+
     Useful to mark pytrees (arrays, submodules, etc) as frozen/non-trainable. We also
-    filter out these modules when partitioning parameters for training, or when
+    filter out NonTrainable nodes when partitioning parameters for training, or when
     parameterizing bijections in coupling/masked autoregressive flows (transformers).
     """
@@ -110,6 +113,26 @@ def unwrap(self) -> T:
         return eqx.combine(lax.stop_gradient(differentiable), static)


+def non_trainable(tree: PyTree):
+    """Freezes parameters by wrapping inexact array leaves with ``NonTrainable``.
+
+    Wrapping the arrays rather than the entire tree is often preferable, as it
+    leaves the tree structure and its attributes directly accessible.
+
+    Args:
+        tree: The pytree.
+    """
+
+    def _map_fn(leaf):
+        return NonTrainable(leaf) if eqx.is_inexact_array(leaf) else leaf
+
+    return jax.tree_util.tree_map(
+        f=_map_fn,
+        tree=tree,
+        is_leaf=lambda x: isinstance(x, NonTrainable),
+    )
+
+
 def _apply_inverse_and_check_valid(bijection, arr):
     param_inv = bijection._vectorize.inverse(arr)
     return eqx.error_if(
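As an illustrative sketch (not part of this PR's changes) of the distinction drawn above between wrapping the leaves and wrapping the whole tree: the snippet below uses only names referenced elsewhere in this PR (``Normal``, ``NonTrainable``, ``non_trainable``, ``unwrap``), and the assertions reflect the behaviour described in the docstrings rather than documented guarantees.

from flowjax.distributions import Normal
from flowjax.wrappers import NonTrainable, non_trainable, unwrap

leaf_wrapped = non_trainable(Normal())  # wraps only the inexact array leaves
tree_wrapped = NonTrainable(Normal())   # hides the whole tree behind the wrapper

assert isinstance(leaf_wrapped, Normal)          # attributes stay directly accessible
assert not isinstance(tree_wrapped, Normal)      # attributes are hidden until unwrapped
assert isinstance(unwrap(tree_wrapped), Normal)  # unwrap applies stop_gradient and returns the tree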
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ license = { file = "LICENSE" }
 name = "flowjax"
 readme = "README.md"
 requires-python = ">=3.10"
-version = "12.3.0"
+version = "12.4.0"

 [project.urls]
 repository = "https://github.com/danielward27/flowjax"
12 changes: 7 additions & 5 deletions tests/test_wrappers.py
@@ -11,6 +11,7 @@
     Lambda,
     NonTrainable,
     WeightNormalization,
+    non_trainable,
     unwrap,
 )

@@ -56,15 +57,16 @@ def test_Lambda():
     assert pytest.approx(unwrap(unwrappable)) == jnp.zeros((3, 2))


-def test_NonTrainable():
-    dist = Normal()
-    dist = eqx.tree_at(lambda dist: dist.bijection, dist, replace_fn=NonTrainable)
+def test_NonTrainable_and_non_trainable():
+    dist1 = eqx.tree_at(lambda dist: dist.bijection, Normal(), replace_fn=NonTrainable)
+    dist2 = non_trainable(Normal())

     def loss(dist, x):
         return dist.log_prob(x)

-    grad = eqx.filter_grad(loss)(dist, 1)
-    assert pytest.approx(0) == jax.flatten_util.ravel_pytree(grad)[0]
+    for dist in [dist1, dist2]:
+        grad = eqx.filter_grad(loss)(dist, 1)
+        assert pytest.approx(0) == jax.flatten_util.ravel_pytree(grad)[0]


 def test_WeightNormalization():