jax.numpy: better docs for matmul-like functions #21092

Merged: 1 commit, May 7, 2024
328 changes: 304 additions & 24 deletions jax/_src/numpy/lax_numpy.py
@@ -3811,20 +3811,71 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,

### Tensor contraction operations


_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION = """
preferred_element_type : dtype, optional
If specified, accumulate results and return a result of the given data type.
If not specified, the accumulation dtype is determined from the type promotion
rules of the input array dtypes.
"""

@util.implements(np.dot, lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def dot(a: ArrayLike, b: ArrayLike, *,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
"""Compute the dot product of two arrays.

JAX implementation of :func:`numpy.dot`.

This differs from :func:`jax.numpy.matmul` in two respects:

- if either ``a`` or ``b`` is a scalar, the result of ``dot`` is equivalent to
:func:`jax.numpy.multiply`, while ``matmul`` raises an error.
- if ``a`` and ``b`` have more than 2 dimensions, the batch dimensions are
stacked rather than broadcast.

Args:
a: first input array, of shape ``(..., N)``.
b: second input array. Must have shape ``(N,)`` or ``(..., N, M)``.
In the multi-dimensional case, the leading dimensions of ``a`` and ``b``
need not match: they are stacked in the output rather than broadcast.
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
such values indicating precision of ``a`` and ``b``.
preferred_element_type: either ``None`` (default), which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.

Returns:
array containing the dot product of the inputs, with batch dimensions of
``a`` and ``b`` stacked rather than broadcast.

See also:
- :func:`jax.numpy.matmul`: broadcasted batched matmul.
- :func:`jax.lax.dot_general`: general batched matrix multiplication.

Examples:
If either input is a scalar, ``dot`` computes the element-wise product:

>>> x = jnp.array([1, 2, 3])
>>> jnp.dot(x, 2)
Array([2, 4, 6], dtype=int32)

For vector or matrix inputs, ``dot`` computes the vector or matrix product:

>>> M = jnp.array([[2, 3, 4],
... [5, 6, 7],
... [8, 9, 0]])
>>> jnp.dot(M, x)
Array([20, 38, 26], dtype=int32)
>>> jnp.dot(M, M)
Array([[ 51, 60, 29],
[ 96, 114, 62],
[ 61, 78, 95]], dtype=int32)

For higher-dimensional matrix products, batch dimensions are stacked, whereas
in :func:`~jax.numpy.matmul` they are broadcast. For example:

>>> a = jnp.zeros((3, 2, 4))
>>> b = jnp.zeros((3, 4, 1))
>>> jnp.dot(a, b).shape
(3, 2, 3, 1)
>>> jnp.matmul(a, b).shape
(3, 2, 1)
"""
util.check_arraylike("dot", a, b)
dtypes.check_user_dtype_supported(preferred_element_type, "dot")
a, b = asarray(a), asarray(b)
@@ -3852,14 +3903,64 @@ def dot(a: ArrayLike, b: ArrayLike, *,
return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)
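
As a quick illustration of the ``preferred_element_type`` parameter documented
above, a minimal sketch (illustrative only, not part of this diff): low-precision
inputs can be accumulated to, and returned at, a wider dtype.

import jax.numpy as jnp

x = jnp.arange(3, dtype=jnp.float16)  # float16 inputs
y = jnp.ones(3, dtype=jnp.float16)

# Default: result dtype follows type promotion of the input dtypes.
print(jnp.dot(x, y).dtype)                                      # float16
# Request float32 accumulation and output.
print(jnp.dot(x, y, preferred_element_type=jnp.float32).dtype)  # float32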


@util.implements(np.matmul, module='numpy', lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def matmul(a: ArrayLike, b: ArrayLike, *,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
) -> Array:
"""Matrix Multiply."""
"""Perform a matrix multiplication.

JAX implementation of :func:`numpy.matmul`.

Args:
a: first input array, of shape ``(..., N)``.
b: second input array. Must have shape ``(N,)`` or ``(..., N, M)``.
In the multi-dimensional case, leading dimensions must be broadcast-compatible
with the leading dimensions of ``a``.
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
such values indicating precision of ``a`` and ``b``.
preferred_element_type: either ``None`` (default), which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.

Returns:
array containing the matrix product of the inputs. Shape is ``a.shape[:-1]``
if ``b.ndim == 1``, otherwise the shape is ``(..., M)``, where leading
dimensions of ``a`` and ``b`` are broadcast together.

See Also:
- :func:`jax.numpy.linalg.vecdot`: batched vector product.
- :func:`jax.numpy.linalg.tensordot`: batched tensor product.
- :func:`jax.lax.dot_general`: general N-dimensional batched dot product.

Examples:
Vector dot product:

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.matmul(a, b)
Array(32, dtype=int32)

Matrix dot product:

>>> a = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> b = jnp.array([[1, 2],
... [3, 4],
... [5, 6]])
>>> jnp.matmul(a, b)
Array([[22, 28],
[49, 64]], dtype=int32)

For convenience, in all of these cases the same computation can be written
with the ``@`` operator:

>>> a @ b
Array([[22, 28],
[49, 64]], dtype=int32)
"""
util.check_arraylike("matmul", a, b)
dtypes.check_user_dtype_supported(preferred_element_type, "matmul")
a, b = asarray(a), asarray(b)
@@ -3925,29 +4026,99 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)
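
To complement the docstring above, a minimal sketch (illustrative only, not
part of this diff) contrasting ``matmul``'s broadcast batch semantics with
``dot``'s stacking, plus a per-call ``precision`` request:

import jax
import jax.numpy as jnp

a = jnp.zeros((1, 2, 4))  # leading batch dim of 1 broadcasts against 3 below
b = jnp.zeros((3, 4, 1))

print(jnp.matmul(a, b).shape)  # (3, 2, 1): batch dimensions broadcast
print(jnp.dot(a, b).shape)     # (1, 2, 3, 1): batch dimensions stacked

# Precision can be requested per call via the lax Precision enum.
print(jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST).shape)  # (3, 2, 1)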


@util.implements(np.vdot, lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def vdot(
a: ArrayLike, b: ArrayLike, *,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
) -> Array:
"""Perform a conjugate multiplication of two 1D vectors.

JAX implementation of :func:`numpy.vdot`.

Args:
a: first input array, if not 1D it will be flattened.
b: second input array, if not 1D it will be flattened. Must have ``a.size == b.size``.
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
such values indicating precision of ``a`` and ``b``.
preferred_element_type: either ``None`` (default), which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.

Returns:
Scalar array (shape ``()``) containing the conjugate vector product of the inputs.

See Also:
- :func:`jax.numpy.vecdot`: batched vector product.
- :func:`jax.numpy.matmul`: general matrix multiplication.
- :func:`jax.lax.dot_general`: general N-dimensional batched dot product.

Examples:
>>> x = jnp.array([1j, 2j, 3j])
>>> y = jnp.array([1., 2., 3.])
>>> jnp.vdot(x, y)
Array(0.-14.j, dtype=complex64)

Note the difference between this and :func:`~jax.numpy.dot`, which does not
conjugate the first input when complex:

>>> jnp.dot(x, y)
Array(0.+14.j, dtype=complex64)
"""
util.check_arraylike("vdot", a, b)
if issubdtype(_dtype(a), complexfloating):
a = ufuncs.conj(a)
return dot(ravel(a), ravel(b), precision=precision,
preferred_element_type=preferred_element_type)
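
A minimal sketch (illustrative only, not part of this diff) of the flattening
behavior described above: for inputs that are not 1D, ``vdot`` matches a
conjugate dot product of the raveled arrays.

import jax.numpy as jnp

a = jnp.array([[1j, 2j], [3j, 4j]])
b = jnp.array([[1., 2.], [3., 4.]])

print(jnp.vdot(a, b))                           # Array(0.-30.j, dtype=complex64)
print(jnp.dot(jnp.conj(a).ravel(), b.ravel()))  # same value, spelled out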


@util.implements(
getattr(np, "vecdot", None), lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION,
# TODO(phawkins): numpy.vecdot doesn't have a __module__ attribute.
module="numpy")
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
"""Perform a conjugate multiplication of two batched vectors.

JAX implementation of :func:`numpy.vecdot`.

Args:
x1: left-hand side array.
x2: right-hand side array. Size of ``x2[axis]`` must match size of ``x1[axis]``,
and remaining dimensions must be broadcast-compatible.
axis: axis along which to compute the dot product (default: -1).
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
such values indicating precision of ``x1`` and ``x2``.
preferred_element_type: either ``None`` (default), which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.

Returns:
array containing the conjugate dot product of ``x1`` and ``x2`` along ``axis``.
The non-contracted dimensions are broadcast together.

See Also:
- :func:`jax.numpy.vdot`: flattened vector product.
- :func:`jax.numpy.matmul`: general matrix multiplication.
- :func:`jax.lax.dot_general`: general N-dimensional batched dot product.

Examples:
Vector conjugate-dot product of two 1D arrays:

>>> a = jnp.array([1j, 2j, 3j])
>>> b = jnp.array([4., 5., 6.])
>>> jnp.vecdot(a, b)
Array(0.-32.j, dtype=complex64)

Batched vector dot product of two 2D arrays:

>>> a = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> b = jnp.array([[2, 3, 4]])
>>> jnp.vecdot(a, b, axis=-1)
Array([20, 47], dtype=int32)
"""
util.check_arraylike("jnp.vecdot", x1, x2)
x1_arr, x2_arr = asarray(x1), asarray(x2)
if x1_arr.shape[axis] != x2_arr.shape[axis]:
@@ -3958,12 +4129,81 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
signature="(n),(n)->()")(x1_arr, x2_arr)
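
A minimal sketch (illustrative only, not part of this diff) of the ``axis``
argument: the contraction need not run along the last axis.

import jax.numpy as jnp

a = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])  # shape (2, 3)
b = jnp.array([[1., 1., 1.],
               [2., 2., 2.]])  # shape (2, 3)

print(jnp.vecdot(a, b, axis=0))  # contracts the length-2 axis -> shape (3,)
print(jnp.vecdot(a, b))          # default axis=-1 -> shape (2,)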


@util.implements(np.tensordot, lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
def tensordot(a: ArrayLike, b: ArrayLike,
axes: int | Sequence[int] | Sequence[Sequence[int]] = 2,
*, precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
"""Compute the tensor dot product of two N-dimensional arrays.

JAX implementation of :func:`numpy.tensordot`.

Args:
a: N-dimensional array
b: M-dimensional array
axes: integer or tuple of sequences of integers. If an integer ``k``, then
sum over the last ``k`` axes of ``a`` and the first ``k`` axes of ``b``,
in order. If a tuple, then ``axes[0]`` specifies the axes of ``a`` and
``axes[1]`` specifies the axes of ``b``.
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
such values indicating precision of ``a`` and ``b``.
preferred_element_type: either ``None`` (default), which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.

Returns:
array containing the tensor dot product of the inputs

See also:
- :func:`jax.numpy.einsum`: NumPy API for more general tensor contractions.
- :func:`jax.lax.dot_general`: XLA API for more general tensor contractions.

Examples:
>>> x1 = jnp.arange(24.).reshape(2, 3, 4)
>>> x2 = jnp.ones((3, 4, 5))
>>> jnp.tensordot(x1, x2)
Array([[ 66., 66., 66., 66., 66.],
[210., 210., 210., 210., 210.]], dtype=float32)

Equivalent result when specifying the axes as explicit sequences:

>>> jnp.tensordot(x1, x2, axes=([1, 2], [0, 1]))
Array([[ 66., 66., 66., 66., 66.],
[210., 210., 210., 210., 210.]], dtype=float32)

Equivalent result via :func:`~jax.numpy.einsum`:

>>> jnp.einsum('ijk,jkm->im', x1, x2)
Array([[ 66., 66., 66., 66., 66.],
[210., 210., 210., 210., 210.]], dtype=float32)

Setting ``axes=1`` for two-dimensional inputs is equivalent to a matrix
multiplication:

>>> x1 = jnp.array([[1, 2],
... [3, 4]])
>>> x2 = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> jnp.tensordot(x1, x2, axes=1)
Array([[ 9, 12, 15],
[19, 26, 33]], dtype=int32)
>>> x1 @ x2
Array([[ 9, 12, 15],
[19, 26, 33]], dtype=int32)

Setting ``axes=0`` for one-dimensional inputs is equivalent to
:func:`~jax.numpy.outer`:

>>> x1 = jnp.array([1, 2])
>>> x2 = jnp.array([1, 2, 3])
>>> jnp.tensordot(x1, x2, axes=0)
Array([[1, 2, 3],
[2, 4, 6]], dtype=int32)
>>> jnp.outer(x1, x2)
Array([[1, 2, 3],
[2, 4, 6]], dtype=int32)
"""
util.check_arraylike("tensordot", a, b)
dtypes.check_user_dtype_supported(preferred_element_type, "tensordot")
a, b = asarray(a), asarray(b)
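
A minimal sketch (illustrative only, not part of this diff) of the
tuple-of-sequences form of ``axes``, contracting axis 0 of ``x1`` against
axis 1 of ``x2``, together with the equivalent einsum spelling:

import jax.numpy as jnp

x1 = jnp.arange(6.).reshape(2, 3)
x2 = jnp.arange(8.).reshape(4, 2)

out = jnp.tensordot(x1, x2, axes=([0], [1]))
print(out.shape)  # (3, 4): free axes of x1, then free axes of x2
print(jnp.allclose(out, jnp.einsum('ij,ki->jk', x1, x2)))  # True
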
@@ -4247,13 +4487,53 @@ def filter_singleton_dims(operand, names, other_shape, other_names):
return lax_internal._convert_element_type(operands[0], preferred_element_type, output_weak_type)


@util.implements(np.inner, lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def inner(
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None,
preferred_element_type: DType | None = None,
) -> Array:
"""Compute the inner product of two arrays.

JAX implementation of :func:`numpy.inner`.

Unlike :func:`jax.numpy.matmul` or :func:`jax.numpy.dot`, this always performs
a contraction along the last dimension of each input.

Args:
a: array of shape ``(..., N)``
b: array of shape ``(..., N)``
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
such values indicating precision of ``a`` and ``b``.
preferred_element_type: either ``None`` (default), which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.

Returns:
array of shape ``(*a.shape[:-1], *b.shape[:-1])`` containing the batched vector
product of the inputs.

See also:
- :func:`jax.numpy.vecdot`: conjugate multiplication along a specified axis.
- :func:`jax.numpy.tensordot`: general tensor multiplication.
- :func:`jax.numpy.matmul`: general batched matrix & vector multiplication.

Examples:
For 1D inputs, this implements standard (non-conjugate) vector multiplication:

>>> a = jnp.array([1j, 3j, 4j])
>>> b = jnp.array([4., 2., 5.])
>>> jnp.inner(a, b)
Array(0.+30.j, dtype=complex64)

For multi-dimensional inputs, batch dimensions are stacked rather than broadcast:

>>> a = jnp.ones((2, 3))
>>> b = jnp.ones((5, 3))
>>> jnp.inner(a, b).shape
(2, 5)
"""
util.check_arraylike("inner", a, b)
if ndim(a) == 0 or ndim(b) == 0:
a = asarray(a, dtype=preferred_element_type)
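
A minimal sketch (illustrative only, not part of this diff) contrasting
``inner`` (no conjugation) with ``vecdot`` (which conjugates its first
argument) on complex inputs:

import jax.numpy as jnp

a = jnp.array([1j, 3j, 4j])
b = jnp.array([4., 2., 5.])

print(jnp.inner(a, b))   # Array(0.+30.j, ...): plain product-sum
print(jnp.vecdot(a, b))  # Array(0.-30.j, ...): first argument conjugated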