diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 20c7c900f1b3..4707a1de8449 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -4315,7 +4315,7 @@ def einsum(
       package. Other options are ``True`` (same as ``"optimal"``), ``False``
       (unoptimized), or any string supported by ``opt_einsum``, which
       includes ``"auto"``, ``"greedy"``, ``"eager"``, and others. It may also
-      be a pre-computed path (see :func:`~jax.numpy.einsum_path`)
+      be a pre-computed path (see :func:`~jax.numpy.einsum_path`).
     precision: either ``None`` (default), which means the default precision for
       the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
       ``Precision.HIGH`` or ``Precision.HIGHEST``).
@@ -4547,9 +4547,77 @@ def _default_poly_einsum_handler(*operands, **kwargs):
   contract_operands = [operands[mapping[id(d)]] for d in out_dummies]
   return contract_operands, contractions
 
-@util.implements(np.einsum_path)
-def einsum_path(subscripts, *operands, optimize='greedy'):
-  # using einsum_call=True here is an internal api for opt_einsum
+@overload
+def einsum_path(
+    subscripts: str, /,
+    *operands: ArrayLike,
+    optimize: bool | str | list[tuple[int, ...]] = ...,
+) -> tuple[list[tuple[int, ...]], Any]: ...
+
+@overload
+def einsum_path(
+    arr: ArrayLike,
+    axes: Sequence[Any], /,
+    *operands: ArrayLike | Sequence[Any],
+    optimize: bool | str | list[tuple[int, ...]] = ...,
+) -> tuple[list[tuple[int, ...]], Any]: ...
+
+def einsum_path(
+    subscripts, /,
+    *operands,
+    optimize: bool | str | list[tuple[int, ...]] = 'auto'
+  ) -> tuple[list[tuple[int, ...]], Any]:
+  """Evaluates the optimal contraction path without evaluating the einsum.
+
+  JAX implementation of :func:`numpy.einsum_path`. This function calls into
+  the opt_einsum_ package, and makes use of its optimization routines.
+
+  Args:
+    subscripts: string containing axes names separated by commas.
+    *operands: sequence of one or more arrays corresponding to the subscripts.
+    optimize: specify how to optimize the order of computation. In JAX this defaults
+      to ``"auto"``. Other options are ``True`` (same as ``"optimal"``), ``False``
+      (unoptimized), or any string supported by ``opt_einsum``, which
+      includes ``"optimal"``, ``"greedy"``, ``"eager"``, and others.
+
+  Returns:
+    A tuple containing the path that may be passed to :func:`~jax.numpy.einsum`, and a
+    printable object representing this optimal path.
+
+  Example:
+    >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
+    >>> x = jax.random.randint(key1, minval=-5, maxval=5, shape=(2, 3))
+    >>> y = jax.random.randint(key2, minval=-5, maxval=5, shape=(3, 100))
+    >>> z = jax.random.randint(key3, minval=-5, maxval=5, shape=(100, 5))
+    >>> path, path_info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="optimal")
+    >>> print(path)
+    [(1, 2), (0, 1)]
+    >>> print(path_info)
+      Complete contraction:  ij,jk,kl->il
+             Naive scaling:  4
+         Optimized scaling:  3
+          Naive FLOP count:  9.000e+3
+      Optimized FLOP count:  3.060e+3
+       Theoretical speedup:  2.941e+0
+      Largest intermediate:  1.500e+1 elements
+    --------------------------------------------------------------------------------
+    scaling        BLAS                current                             remaining
+    --------------------------------------------------------------------------------
+       3           GEMM              kl,jk->lj                          ij,lj->il
+       3           GEMM              lj,ij->il                             il->il
+
+    Use the computed path in :func:`~jax.numpy.einsum`:
+
+    >>> jnp.einsum("ij,jk,kl", x, y, z, optimize=path)
+    Array([[-539,  216,   95,  592,  209],
+           [ 527,   76,  285, -436, -529]], dtype=int32)
+
+  .. _opt_einsum: https://github.com/dgasmith/opt_einsum
+  """
+  if optimize is True:
+    optimize = 'optimal'
+  elif optimize is False:
+    optimize = Unoptimized()
   return opt_einsum.contract_path(subscripts, *operands, optimize=optimize)
 
 def _removechars(s, chars):
diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi
index a5973be6435c..ef1ffd2bba40 100644
--- a/jax/numpy/__init__.pyi
+++ b/jax/numpy/__init__.pyi
@@ -317,7 +317,26 @@ def einsum(
     _dot_general: Callable[..., Array] = ...,
 ) -> Array: ...
 
-def einsum_path(subscripts, *operands, optimize = ...): ...
+@overload
+def einsum_path(
+    subscripts: str, /,
+    *operands: ArrayLike,
+    optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ...,
+) -> tuple[list[tuple[int, ...]], Any]: ...
+@overload
+def einsum_path(
+    arr: ArrayLike,
+    axes: Sequence[Any], /,
+    *operands: Union[ArrayLike, Sequence[Any]],
+    optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ...,
+) -> tuple[list[tuple[int, ...]], Any]: ...
+@overload
+def einsum_path(
+    subscripts, /,
+    *operands: ArrayLike,
+    optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ...,
+) -> tuple[list[tuple[int, ...]], Any]: ...
+
 def empty(shape: Any, dtype: Optional[DTypeLike] = ...,
           device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ...
 def empty_like(prototype: Union[ArrayLike, DuckTypedArray],
diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py
index b7a3df68ad22..7397cf3e4ee8 100644
--- a/tests/lax_numpy_einsum_test.py
+++ b/tests/lax_numpy_einsum_test.py
@@ -411,6 +411,27 @@ def test_einsum_optimization_modes(self, signature, shapes, optimize, dtype):
     self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=1E-4)
     self._CompileAndCheck(jnp_fun, args_maker, rtol=1E-4)
 
+  @jtu.sample_product(
+    [
+      {'signature': 'i->', 'shapes': [(3,)]},
+      {'signature': 'ii->i', 'shapes': [(4, 4)]},
+      {'signature': 'ij,jk', 'shapes': [(3, 4), (4, 3)]},
+      {'signature': 'ij,jkl,klm', 'shapes': [(2, 2), (2, 3, 4), (3, 4, 2)]},
+    ],
+    optimize=[True, False, 'optimal', 'auto', 'greedy', 'eager'],
+    dtype=[np.dtype('float32')],
+  )
+  @jtu.skip_on_devices('tpu')
+  def test_einsum_path_optimization_modes(self, signature, shapes, optimize, dtype):
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(shape, dtype) for shape in shapes]
+    def jnp_fun(*args, signature=signature, optimize=optimize):
+      path, _ = jnp.einsum_path(signature, *args, optimize=optimize)
+      return jnp.einsum(signature, *args, optimize=path)
+    np_fun = partial(np.einsum, signature)
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=1E-4)
+    self._CompileAndCheck(jnp_fun, args_maker, rtol=1E-4)
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
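
Note: the new einsum_path body maps optimize=False to Unoptimized(), a helper whose definition is not included in the hunks above. The sketch below is an assumption, not code from this PR: it shows what such a left-to-right (unoptimized) path optimizer could look like using opt_einsum's documented custom-optimizer protocol (subclass opt_einsum.paths.PathOptimizer and return a contraction path from __call__).

# Hypothetical sketch -- not part of the diff above.
import opt_einsum


class Unoptimized(opt_einsum.paths.PathOptimizer):
  """Contract operands pairwise, left to right, with no reordering."""

  def __call__(self, inputs, output, size_dict, memory_limit=None):
    # A single operand needs no contraction beyond the identity step.
    if len(inputs) == 1:
      return [(0,)]
    # Each (0, 1) step contracts the first two remaining operands, which
    # reproduces naive left-to-right evaluation of the expression.
    return [(0, 1)] * (len(inputs) - 1)


if __name__ == '__main__':
  # Usage: pass the optimizer instance as the `optimize` argument of
  # opt_einsum.contract_path, as the new einsum_path does above.
  import numpy as np
  x, y, z = np.ones((2, 3)), np.ones((3, 4)), np.ones((4, 5))
  path, info = opt_einsum.contract_path('ij,jk,kl', x, y, z,
                                        optimize=Unoptimized())
  print(path)  # expected: [(0, 1), (0, 1)]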