diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 72cc52be0f93..45d888e4e1f6 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -35,7 +35,7 @@ from jax._src.numpy import reductions, ufuncs from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike from jax._src.util import canonicalize_axis -from jax._src.typing import ArrayLike, Array +from jax._src.typing import ArrayLike, Array, DTypeLike class EighResult(NamedTuple): @@ -1612,7 +1612,9 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa return norm(x, axis=axis, keepdims=keepdims, ord=ord) -def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array: +def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None) -> Array: """Compute the (batched) vector conjugate dot product of two arrays. JAX implementation of :func:`numpy.linalg.vecdot`. @@ -1622,6 +1624,13 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array: x2: right-hand side array. Size of ``x2[axis]`` must match size of ``x1[axis]``, and remaining dimensions must be broadcast-compatible. axis: axis along which to compute the dot product (default: -1) + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + such values indicating precision of ``x1`` and ``x2``. + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. Returns: array containing the conjugate dot product of ``x1`` and ``x2`` along ``axis``. @@ -1649,10 +1658,13 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array: Array([20, 47], dtype=int32) """ check_arraylike('jnp.linalg.vecdot', x1, x2) - return jnp.vecdot(x1, x2, axis=axis) + return jnp.vecdot(x1, x2, axis=axis, precision=precision, + preferred_element_type=preferred_element_type) -def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array: +def matmul(x1: ArrayLike, x2: ArrayLike, /, *, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None) -> Array: """Perform a matrix multiplication. JAX implementation of :func:`numpy.linalg.matmul`. @@ -1662,6 +1674,13 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array: x2: second input array. Must have shape ``(N,)`` or ``(..., N, M)``. In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of ``x1``. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + such values indicating precision of ``x1`` and ``x2``. + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. Returns: array containing the matrix product of the inputs. Shape is ``x1.shape[:-1]`` @@ -1699,11 +1718,14 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array: [49, 64]], dtype=int32) """ check_arraylike('jnp.linalg.matmul', x1, x2) - return jnp.matmul(x1, x2) + return jnp.matmul(x1, x2, precision=precision, + preferred_element_type=preferred_element_type) def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, - axes: int | tuple[Sequence[int], Sequence[int]] = 2) -> Array: + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None) -> Array: """Compute the tensor dot product of two N-dimensional arrays. JAX implementation of :func:`numpy.linalg.tensordot`. @@ -1715,6 +1737,13 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, sum over the last `k` axes of ``x1`` and the first `k` axes of ``x2``, in order. If a tuple, then ``axes[0]`` specifies the axes of ``x1`` and ``axes[1]`` specifies the axes of ``x2``. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two + such values indicating precision of ``x1`` and ``x2``. + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. Returns: array containing the tensor dot product of the inputs @@ -1770,7 +1799,8 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, [2, 4, 6]], dtype=int32) """ check_arraylike('jnp.linalg.tensordot', x1, x2) - return jnp.tensordot(x1, x2, axes=axes) + return jnp.tensordot(x1, x2, axes=axes, precision=precision, + preferred_element_type=preferred_element_type) def svdvals(x: ArrayLike, /) -> Array: diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 22bc6c62b20d..8d21c46d45bd 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -697,6 +697,12 @@ def testVecdot(self, lhs_shape, rhs_shape, axis, dtype): self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) self._CompileAndCheck(jnp_fn, args_maker, tol=tol) + # smoke-test for optional kwargs. + jnp_fn = partial(jnp.linalg.vecdot, axis=axis, + precision=lax.Precision.HIGHEST, + preferred_element_type=dtype) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + # jnp.linalg.matmul is an alias of jnp.matmul; do a minimal test here. @jtu.sample_product( [ @@ -719,6 +725,12 @@ def testMatmul(self, lhs_shape, rhs_shape, dtype): self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) self._CompileAndCheck(jnp_fn, args_maker, tol=tol) + # smoke-test for optional kwargs. + jnp_fn = partial(jnp.linalg.matmul, + precision=lax.Precision.HIGHEST, + preferred_element_type=dtype) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + # jnp.linalg.tensordot is an alias of jnp.tensordot; do a minimal test here. @jtu.sample_product( [ @@ -742,6 +754,12 @@ def testTensordot(self, lhs_shape, rhs_shape, axes, dtype): self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) self._CompileAndCheck(jnp_fn, args_maker, tol=tol) + # smoke-test for optional kwargs. + jnp_fn = partial(jnp.linalg.tensordot, axes=axes, + precision=lax.Precision.HIGHEST, + preferred_element_type=dtype) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + @jtu.sample_product( [ dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)