diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index c0ad4aded052..72cc52be0f93 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -2017,4 +2017,4 @@ def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) - if arrs[-1].ndim == 1: einsum_axes[-1] = einsum_axes[-1][:1] return jnp.einsum(*itertools.chain(*zip(arrs, einsum_axes)), # type: ignore[arg-type, call-overload] - optimize='optimal', precision=precision) + optimize='auto', precision=precision)