Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jnp.einsum_path: improve docs & annotations #21167

Merged
merged 1 commit into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 72 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4315,7 +4315,7 @@ def einsum(
package. Other options are ``True`` (same as ``"optimal"``), ``False``
(unoptimized), or any string supported by ``opt_einsum``, which
includes ``"auto"``, ``"greedy"``, ``"eager"``, and others. It may also
be a pre-computed path (see :func:`~jax.numpy.einsum_path`)
be a pre-computed path (see :func:`~jax.numpy.einsum_path`).
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``).
Expand Down Expand Up @@ -4547,9 +4547,77 @@ def _default_poly_einsum_handler(*operands, **kwargs):
contract_operands = [operands[mapping[id(d)]] for d in out_dummies]
return contract_operands, contractions

@util.implements(np.einsum_path)
def einsum_path(subscripts, *operands, optimize='greedy'):
# using einsum_call=True here is an internal api for opt_einsum
@overload
def einsum_path(
subscripts: str, /,
*operands: ArrayLike,
optimize: bool | str | list[tuple[int, ...]] = ...,
) -> tuple[list[tuple[int, ...]], Any]: ...

@overload
def einsum_path(
arr: ArrayLike,
axes: Sequence[Any], /,
*operands: ArrayLike | Sequence[Any],
optimize: bool | str | list[tuple[int, ...]] = ...,
) -> tuple[list[tuple[int, ...]], Any]: ...

def einsum_path(
subscripts, /,
*operands,
optimize: bool | str | list[tuple[int, ...]] = 'auto'
) -> tuple[list[tuple[int, ...]], Any]:
"""Evaluates the optimal contraction path without evaluating the einsum.

JAX implementation of :func:`jax.numpy.einsum_path`. This function calls into
the opt_einsum_ package, and makes use of its optimization routines.

Args:
subscripts: string containing axes names separated by commas.
*operands: sequence of one or more arrays corresponding to the subscripts.
optimize: specify how to optimize the order of computation. In JAX this defaults
to ``"auto"``. Other options are ``True`` (same as ``"optimize"``), ``False``
(unoptimized), or any string supported by ``opt_einsum``, which
includes ``"optimize"``,, ``"greedy"``, ``"eager"``, and others.

Returns:
A tuple containing the path that may be passed to :func:`~jax.numpy.einsum`, and a
printable object representing this optimal path.

Example:
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.randint(key1, minval=-5, maxval=5, shape=(2, 3))
>>> y = jax.random.randint(key2, minval=-5, maxval=5, shape=(3, 100))
>>> z = jax.random.randint(key3, minval=-5, maxval=5, shape=(100, 5))
>>> path, path_info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="optimal")
>>> print(path)
[(1, 2), (0, 1)]
>>> print(path_info)
Complete contraction: ij,jk,kl->il
Naive scaling: 4
Optimized scaling: 3
Naive FLOP count: 9.000e+3
Optimized FLOP count: 3.060e+3
Theoretical speedup: 2.941e+0
Largest intermediate: 1.500e+1 elements
--------------------------------------------------------------------------------
scaling BLAS current remaining
--------------------------------------------------------------------------------
3 GEMM kl,jk->lj ij,lj->il
3 GEMM lj,ij->il il->il

Use the computed path in :func:`~jax.numpy.einsum`:

>>> jnp.einsum("ij,jk,kl", x, y, z, optimize=path)
Array([[-539, 216, 95, 592, 209],
[ 527, 76, 285, -436, -529]], dtype=int32)

.. _opt_einsum: https://github.com/dgasmith/opt_einsum
"""
if optimize is True:
optimize = 'optimal'
elif optimize is False:
optimize = Unoptimized()
return opt_einsum.contract_path(subscripts, *operands, optimize=optimize)

def _removechars(s, chars):
Expand Down
21 changes: 20 additions & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,26 @@ def einsum(
_dot_general: Callable[..., Array] = ...,
) -> Array: ...

def einsum_path(subscripts, *operands, optimize = ...): ...
@overload
def einsum_path(
subscripts: str, /,
*operands: ArrayLike,
optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ...,
) -> tuple[list[tuple[int, ...]], Any]: ...
@overload
def einsum_path(
arr: ArrayLike,
axes: Sequence[Any], /,
*operands: Union[ArrayLike, Sequence[Any]],
optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ...,
) -> tuple[list[tuple[int, ...]], Any]: ...
@overload
def einsum_path(
subscripts, /,
*operands: ArrayLike,
optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ...,
) -> tuple[list[tuple[int, ...]], Any]: ...

def empty(shape: Any, dtype: Optional[DTypeLike] = ...,
device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
def empty_like(prototype: Union[ArrayLike, DuckTypedArray],
Expand Down
21 changes: 21 additions & 0 deletions tests/lax_numpy_einsum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,27 @@ def test_einsum_optimization_modes(self, signature, shapes, optimize, dtype):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=1E-4)
self._CompileAndCheck(jnp_fun, args_maker, rtol=1E-4)

@jtu.sample_product(
[
{'signature': 'i->', 'shapes': [(3,)]},
{'signature': 'ii->i', 'shapes': [(4, 4)]},
{'signature': 'ij,jk', 'shapes': [(3, 4), (4, 3)]},
{'signature': 'ij,jkl,klm', 'shapes': [(2, 2), (2, 3, 4), (3, 4, 2)]},
],
optimize=[True, False, 'optimal', 'auto', 'greedy', 'eager'],
dtype=[np.dtype('float32')],
)
@jtu.skip_on_devices('tpu')
def test_einsum_path_optimization_modes(self, signature, shapes, optimize, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
def jnp_fun(*args, signature=signature, optimize=optimize):
path, _ = jnp.einsum_path(signature, *args, optimize=optimize)
return jnp.einsum(signature, *args, optimize=path)
np_fun = partial(np.einsum, signature)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=1E-4)
self._CompileAndCheck(jnp_fun, args_maker, rtol=1E-4)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())