Merge pull request #21106 from jakevdp:linalg-precision
PiperOrigin-RevId: 632217396
jax authors committed May 9, 2024
2 parents 0c4d81c + 2ddb7ff commit 1a7a2aa
Showing 2 changed files with 55 additions and 7 deletions.
44 changes: 37 additions & 7 deletions jax/_src/numpy/linalg.py
@@ -35,7 +35,7 @@
 from jax._src.numpy import reductions, ufuncs
 from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike
 from jax._src.util import canonicalize_axis
-from jax._src.typing import ArrayLike, Array
+from jax._src.typing import ArrayLike, Array, DTypeLike


 class EighResult(NamedTuple):
@@ -1612,7 +1612,9 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa
   return norm(x, axis=axis, keepdims=keepdims, ord=ord)


-def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
+def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
+           precision: PrecisionLike = None,
+           preferred_element_type: DTypeLike | None = None) -> Array:
   """Compute the (batched) vector conjugate dot product of two arrays.

   JAX implementation of :func:`numpy.linalg.vecdot`.
@@ -1622,6 +1624,13 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
     x2: right-hand side array. Size of ``x2[axis]`` must match size of ``x1[axis]``,
       and remaining dimensions must be broadcast-compatible.
     axis: axis along which to compute the dot product (default: -1)
+    precision: either ``None`` (default), which means the default precision for
+      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
+      such values indicating precision of ``x1`` and ``x2``.
+    preferred_element_type: either ``None`` (default), which means the default
+      accumulation type for the input types, or a datatype, indicating to
+      accumulate results to and return a result with that datatype.

   Returns:
     array containing the conjugate dot product of ``x1`` and ``x2`` along ``axis``.
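For reference, a minimal sketch of how the two new keywords compose for `vecdot` (assuming a JAX build that includes this change; the array values are illustrative, not taken from the diff):

```python
import jax.numpy as jnp
from jax import lax

x1 = jnp.array([[1., 2., 3.],
                [4., 5., 6.]])
x2 = jnp.array([1., 1., 1.])

# Default behavior is unchanged: contract along the last axis.
jnp.linalg.vecdot(x1, x2)  # Array([ 6., 15.], dtype=float32)

# Request the backend's most precise algorithm and an explicit
# float32 accumulation/result type.
jnp.linalg.vecdot(x1, x2, axis=-1,
                  precision=lax.Precision.HIGHEST,
                  preferred_element_type=jnp.float32)
```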
@@ -1649,10 +1658,13 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
   Array([20, 47], dtype=int32)
   """
   check_arraylike('jnp.linalg.vecdot', x1, x2)
-  return jnp.vecdot(x1, x2, axis=axis)
+  return jnp.vecdot(x1, x2, axis=axis, precision=precision,
+                    preferred_element_type=preferred_element_type)


-def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array:
+def matmul(x1: ArrayLike, x2: ArrayLike, /, *,
+           precision: PrecisionLike = None,
+           preferred_element_type: DTypeLike | None = None) -> Array:
   """Perform a matrix multiplication.

   JAX implementation of :func:`numpy.linalg.matmul`.
@@ -1662,6 +1674,13 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array:
     x2: second input array. Must have shape ``(N,)`` or ``(..., N, M)``.
       In the multi-dimensional case, leading dimensions must be broadcast-compatible
       with the leading dimensions of ``x1``.
+    precision: either ``None`` (default), which means the default precision for
+      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
+      such values indicating precision of ``x1`` and ``x2``.
+    preferred_element_type: either ``None`` (default), which means the default
+      accumulation type for the input types, or a datatype, indicating to
+      accumulate results to and return a result with that datatype.

   Returns:
     array containing the matrix product of the inputs. Shape is ``x1.shape[:-1]``
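The main use of `preferred_element_type` on `matmul` is keeping low-precision inputs while accumulating in a wider type; a sketch under assumed shapes and dtypes (not taken from this diff):

```python
import jax.numpy as jnp

a = jnp.ones((4, 8), dtype=jnp.bfloat16)
b = jnp.ones((8, 2), dtype=jnp.bfloat16)

# Without the keyword, the result dtype follows the inputs.
jnp.linalg.matmul(a, b).dtype  # bfloat16

# With it, products are accumulated to and returned as float32.
jnp.linalg.matmul(a, b, preferred_element_type=jnp.float32).dtype  # float32
```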
@@ -1699,11 +1718,14 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array:
          [49, 64]], dtype=int32)
   """
   check_arraylike('jnp.linalg.matmul', x1, x2)
-  return jnp.matmul(x1, x2)
+  return jnp.matmul(x1, x2, precision=precision,
+                    preferred_element_type=preferred_element_type)


 def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
-              axes: int | tuple[Sequence[int], Sequence[int]] = 2) -> Array:
+              axes: int | tuple[Sequence[int], Sequence[int]] = 2,
+              precision: PrecisionLike = None,
+              preferred_element_type: DTypeLike | None = None) -> Array:
   """Compute the tensor dot product of two N-dimensional arrays.

   JAX implementation of :func:`numpy.linalg.tensordot`.
@@ -1715,6 +1737,13 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
       sum over the last `k` axes of ``x1`` and the first `k` axes of ``x2``,
       in order. If a tuple, then ``axes[0]`` specifies the axes of ``x1`` and
       ``axes[1]`` specifies the axes of ``x2``.
+    precision: either ``None`` (default), which means the default precision for
+      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
+      such values indicating precision of ``x1`` and ``x2``.
+    preferred_element_type: either ``None`` (default), which means the default
+      accumulation type for the input types, or a datatype, indicating to
+      accumulate results to and return a result with that datatype.

   Returns:
     array containing the tensor dot product of the inputs
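As with `vecdot` and `matmul`, `precision` may also be a two-tuple, one entry per operand; a sketch with illustrative shapes and axes:

```python
import jax.numpy as jnp
from jax import lax

x1 = jnp.ones((3, 4, 5))
x2 = jnp.ones((4, 5, 2))

# Contract axes (1, 2) of x1 against axes (0, 1) of x2, specifying
# precision per operand; the result has shape (3, 2).
out = jnp.linalg.tensordot(
    x1, x2, axes=((1, 2), (0, 1)),
    precision=(lax.Precision.HIGH, lax.Precision.HIGHEST))
```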
@@ -1770,7 +1799,8 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
          [2, 4, 6]], dtype=int32)
   """
   check_arraylike('jnp.linalg.tensordot', x1, x2)
-  return jnp.tensordot(x1, x2, axes=axes)
+  return jnp.tensordot(x1, x2, axes=axes, precision=precision,
+                       preferred_element_type=preferred_element_type)


 def svdvals(x: ArrayLike, /) -> Array:
18 changes: 18 additions & 0 deletions tests/linalg_test.py
@@ -697,6 +697,12 @@ def testVecdot(self, lhs_shape, rhs_shape, axis, dtype):
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
     self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

+    # smoke-test for optional kwargs.
+    jnp_fn = partial(jnp.linalg.vecdot, axis=axis,
+                     precision=lax.Precision.HIGHEST,
+                     preferred_element_type=dtype)
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+
   # jnp.linalg.matmul is an alias of jnp.matmul; do a minimal test here.
   @jtu.sample_product(
     [
@@ -719,6 +725,12 @@ def testMatmul(self, lhs_shape, rhs_shape, dtype):
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
     self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

+    # smoke-test for optional kwargs.
+    jnp_fn = partial(jnp.linalg.matmul,
+                     precision=lax.Precision.HIGHEST,
+                     preferred_element_type=dtype)
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+
   # jnp.linalg.tensordot is an alias of jnp.tensordot; do a minimal test here.
   @jtu.sample_product(
     [
@@ -742,6 +754,12 @@ def testTensordot(self, lhs_shape, rhs_shape, axes, dtype):
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
     self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

+    # smoke-test for optional kwargs.
+    jnp_fn = partial(jnp.linalg.tensordot, axes=axes,
+                     precision=lax.Precision.HIGHEST,
+                     preferred_element_type=dtype)
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+
   @jtu.sample_product(
     [
       dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)
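Outside the test harness, the same smoke test boils down to checking that the keywords are accepted and that `preferred_element_type` is reflected in the result dtype; a standalone sketch in the tests' `partial` style (inputs illustrative):

```python
from functools import partial

import jax.numpy as jnp
from jax import lax

vecdot_fn = partial(jnp.linalg.vecdot, axis=-1,
                    precision=lax.Precision.HIGHEST,
                    preferred_element_type=jnp.float32)

out = vecdot_fn(jnp.ones((2, 3)), jnp.ones(3))
assert out.dtype == jnp.float32  # accumulation dtype propagates to the result
```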
