jax.numpy: better docs for matmul-like functions
jakevdp committed May 6, 2024
1 parent 3d3cb0b commit 6cb4045
Showing 2 changed files with 315 additions and 35 deletions.
328 changes: 304 additions & 24 deletions jax/_src/numpy/lax_numpy.py
@@ -3781,20 +3781,71 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,

### Tensor contraction operations


_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION = """
preferred_element_type : dtype, optional
If specified, accumulate results and return a result of the given data type.
If not specified, the accumulation dtype is determined from the type promotion
rules of the input array dtypes.
"""

@util.implements(np.dot, lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def dot(a: ArrayLike, b: ArrayLike, *,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
"""Compute the dot product of two arrays.
JAX implementation of :func:`numpy.dot`.
This differs from :func:`jax.numpy.matmul` in two respects:
- if either ``a`` or ``b`` is a scalar, the result of ``dot`` is equivalent to
:func:`jax.numpy.multiply`, while the result of ``matmul`` is an error.
- if ``a`` and ``b`` have more than 2 dimensions, the batch indices are
stacked rather than broadcast.
Args:
a: first input array, of shape ``(..., N)``.
b: second input array. Must have shape ``(N,)`` or ``(..., N, M)``.
In the multi-dimensional case, leading dimensions must be broadcast-compatible
with the leading dimensions of ``a``.
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
such values indicating precision of ``a`` and ``b``.
preferred_element_type: either ``None`` (default), which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
array containing the dot product of the inputs, with batch dimensions of
``a`` and ``b`` stacked rather than broadcast.
See also:
- :func:`jax.numpy.matmul`: broadcasted batched matmul.
- :func:`jax.lax.dot_general`: general batched matrix multiplication.
Examples:
For scalar inputs, ``dot`` computes the element-wise product:
>>> x = jnp.array([1, 2, 3])
>>> jnp.dot(x, 2)
Array([2, 4, 6], dtype=int32)
For vector or matrix inputs, ``dot`` computes the vector or matrix product:
>>> M = jnp.array([[2, 3, 4],
... [5, 6, 7],
... [8, 9, 0]])
>>> jnp.dot(M, x)
Array([20, 38, 26], dtype=int32)
>>> jnp.dot(M, M)
Array([[ 51, 60, 29],
[ 96, 114, 62],
[ 61, 78, 95]], dtype=int32)
For higher-dimensional matrix products, batch dimensions are stacked, whereas
in :func:`~jax.numpy.matmul` they are broadcast. For example:
>>> a = jnp.zeros((3, 2, 4))
>>> b = jnp.zeros((3, 4, 1))
>>> jnp.dot(a, b).shape
(3, 2, 3, 1)
>>> jnp.matmul(a, b).shape
(3, 2, 1)
"""
util.check_arraylike("dot", a, b)
dtypes.check_user_dtype_supported(preferred_element_type, "dot")
a, b = asarray(a), asarray(b)
@@ -3822,14 +3873,64 @@ def dot(a: ArrayLike, b: ArrayLike, *,
return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)
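
The ``precision`` argument documented above is passed per call; a minimal sketch (assuming a float32 input, where ``HIGHEST`` requests the most precise algorithm available on the backend):

>>> import jax
>>> a = jnp.ones((2, 2), dtype='float32')
>>> jnp.dot(a, a, precision=jax.lax.Precision.HIGHEST)
Array([[2., 2.],
[2., 2.]], dtype=float32)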


@util.implements(np.matmul, module='numpy', lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def matmul(a: ArrayLike, b: ArrayLike, *,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
) -> Array:
"""Matrix Multiply."""
"""Perform a matrix multiplication.
JAX implementation of :func:`numpy.matmul`.
Args:
a: first input array, of shape ``(..., N)``.
b: second input array. Must have shape ``(N,)`` or ``(..., N, M)``.
In the multi-dimensional case, leading dimensions must be broadcast-compatible
with the leading dimensions of ``a``.
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
such values indicating precision of ``a`` and ``b``.
preferred_element_type: either ``None`` (default), which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
array containing the matrix product of the inputs. Shape is ``a.shape[:-1]``
if ``b.ndim == 1``, otherwise the shape is ``(..., M)``, where leading
dimensions of ``a`` and ``b`` are broadcast together.
See Also:
- :func:`jax.numpy.linalg.vecdot`: batched vector product.
- :func:`jax.numpy.linalg.tensordot`: batched tensor product.
- :func:`jax.lax.dot_general`: general N-dimensional batched dot product.
Examples:
Vector dot products:
>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.matmul(a, b)
Array(32, dtype=int32)
Matrix dot product:
>>> a = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> b = jnp.array([[1, 2],
... [3, 4],
... [5, 6]])
>>> jnp.matmul(a, b)
Array([[22, 28],
[49, 64]], dtype=int32)
For convenience, in all cases you can do the same computation using
the ``@`` operator:
>>> a @ b
Array([[22, 28],
[49, 64]], dtype=int32)
"""
util.check_arraylike("matmul", a, b)
dtypes.check_user_dtype_supported(preferred_element_type, "matmul")
a, b = asarray(a), asarray(b)
@@ -3895,29 +3996,99 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)
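
As the ``Returns`` section above notes, a 1D ``b`` has its axis contracted away, leaving shape ``a.shape[:-1]``; a shape-level sketch of this:

>>> a = jnp.ones((2, 3, 4))
>>> b = jnp.ones(4)
>>> jnp.matmul(a, b).shape
(2, 3)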


@util.implements(np.vdot, lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def vdot(
a: ArrayLike, b: ArrayLike, *,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
) -> Array:
"""Perform a conjugate multiplication of two 1D vectors.
JAX implementation of :func:`numpy.vdot`.
Args:
a: first input array, if not 1D it will be flattened.
b: second input array, if not 1D it will be flattened. Must have ``a.size == b.size``.
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
such values indicating precision of ``a`` and ``b``.
preferred_element_type: either ``None`` (default), which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
Scalar array (shape ``()``) containing the conjugate vector product of the inputs.
See Also:
- :func:`jax.numpy.vecdot`: batched vector product.
- :func:`jax.numpy.matmul`: general matrix multiplication.
- :func:`jax.lax.dot_general`: general N-dimensional batched dot product.
Examples:
>>> x = jnp.array([1j, 2j, 3j])
>>> y = jnp.array([1., 2., 3.])
>>> jnp.vdot(x, y)
Array(0.-14.j, dtype=complex64)
Note the difference between this and :func:`~jax.numpy.dot`, which does not
conjugate the first input when complex:
>>> jnp.dot(x, y)
Array(0.+14.j, dtype=complex64)
"""
util.check_arraylike("vdot", a, b)
if issubdtype(_dtype(a), complexfloating):
a = ufuncs.conj(a)
return dot(ravel(a), ravel(b), precision=precision,
preferred_element_type=preferred_element_type)
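
Because ``vdot`` flattens its inputs first, it also accepts matrices; a small sketch:

>>> a = jnp.array([[1, 2],
... [3, 4]])
>>> b = jnp.array([[5, 6],
... [7, 8]])
>>> jnp.vdot(a, b)
Array(70, dtype=int32)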


@util.implements(
getattr(np, "vecdot", None), lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION,
# TODO(phawkins): numpy.vecdot doesn't have a __module__ attribute.
module="numpy")
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
"""Perform a conjugate multiplication of two batched vectors.
JAX implementation of :func:`numpy.vecdot`.
Args:
a: left-hand side array.
b: right-hand side array. Size of ``b[axis]`` must match size of ``a[axis]``,
and remaining dimensions must be broadcast-compatible.
axis: axis along which to compute the dot product (default: -1)
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
such values indicating precision of ``a`` and ``b``.
preferred_element_type: either ``None`` (default), which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
array containing the conjugate dot product of ``a`` and ``b`` along ``axis``.
The non-contracted dimensions are broadcast together.
See Also:
- :func:`jax.numpy.vdot`: flattened vector product.
- :func:`jax.numpy.matmul`: general matrix multiplication.
- :func:`jax.lax.dot_general`: general N-dimensional batched dot product.
Examples:
Vector conjugate-dot product of two 1D arrays:
>>> a = jnp.array([1j, 2j, 3j])
>>> b = jnp.array([4., 5., 6.])
>>> jnp.linalg.vecdot(a, b)
Array(0.-32.j, dtype=complex64)
Batched vector dot product of two 2D arrays:
>>> a = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> b = jnp.array([[2, 3, 4]])
>>> jnp.linalg.vecdot(a, b, axis=-1)
Array([20, 47], dtype=int32)
"""
util.check_arraylike("jnp.vecdot", x1, x2)
x1_arr, x2_arr = asarray(x1), asarray(x2)
if x1_arr.shape[axis] != x2_arr.shape[axis]:
@@ -3928,12 +4099,81 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
signature="(n),(n)->()")(x1_arr, x2_arr)
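
The ``axis`` argument selects which dimension is contracted; for instance, contracting along the leading axis instead of the default trailing one:

>>> a = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> jnp.vecdot(a, a, axis=0)
Array([17, 29, 45], dtype=int32)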


@util.implements(np.tensordot, lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
def tensordot(a: ArrayLike, b: ArrayLike,
axes: int | Sequence[int] | Sequence[Sequence[int]] = 2,
*, precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
"""Compute the tensor dot product of two N-dimensional arrays.
JAX implementation of :func:`numpy.linalg.tensordot`.
Args:
a: N-dimensional array
b: M-dimensional array
axes: integer or tuple of sequences of integers. If an integer `k`, then
sum over the last `k` axes of ``a`` and the first `k` axes of ``b``,
in order. If a tuple, then ``axes[0]`` specifies the axes of ``a`` and
``axes[1]`` specifies the axes of ``b``.
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
such values indicating precision of ``a`` and ``b``.
preferred_element_type: either ``None`` (default), which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
array containing the tensor dot product of the inputs
See also:
- :func:`jax.numpy.einsum`: NumPy API for more general tensor contractions.
- :func:`jax.lax.dot_general`: XLA API for more general tensor contractions.
Examples:
>>> x1 = jnp.arange(24.).reshape(2, 3, 4)
>>> x2 = jnp.ones((3, 4, 5))
>>> jnp.tensordot(x1, x2)
Array([[ 66., 66., 66., 66., 66.],
[210., 210., 210., 210., 210.]], dtype=float32)
Equivalent result when specifying the axes as explicit sequences:
>>> jnp.tensordot(x1, x2, axes=([1, 2], [0, 1]))
Array([[ 66., 66., 66., 66., 66.],
[210., 210., 210., 210., 210.]], dtype=float32)
Equivalent result via :func:`~jax.numpy.einsum`:
>>> jnp.einsum('ijk,jkm->im', x1, x2)
Array([[ 66., 66., 66., 66., 66.],
[210., 210., 210., 210., 210.]], dtype=float32)
Setting ``axes=1`` for two-dimensional inputs is equivalent to a matrix
multiplication:
>>> x1 = jnp.array([[1, 2],
... [3, 4]])
>>> x2 = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> jnp.linalg.tensordot(x1, x2, axes=1)
Array([[ 9, 12, 15],
[19, 26, 33]], dtype=int32)
>>> x1 @ x2
Array([[ 9, 12, 15],
[19, 26, 33]], dtype=int32)
Setting ``axes=0`` for one-dimensional inputs is equivalent to
:func:`~jax.numpy.outer`:
>>> x1 = jnp.array([1, 2])
>>> x2 = jnp.array([1, 2, 3])
>>> jnp.linalg.tensordot(x1, x2, axes=0)
Array([[1, 2, 3],
[2, 4, 6]], dtype=int32)
>>> jnp.outer(x1, x2)
Array([[1, 2, 3],
[2, 4, 6]], dtype=int32)
"""
util.check_arraylike("tensordot", a, b)
dtypes.check_user_dtype_supported(preferred_element_type, "tensordot")
a, b = asarray(a), asarray(b)
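
The tuple form of ``axes`` described above contracts arbitrary axis pairs; a shape-level sketch, here contracting axis 2 of ``x1`` against axis 0 of ``x2`` (equivalent to ``axes=1``):

>>> x1 = jnp.ones((2, 3, 4))
>>> x2 = jnp.ones((4, 5))
>>> jnp.tensordot(x1, x2, axes=([2], [0])).shape
(2, 3, 5)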
@@ -4217,13 +4457,53 @@ def filter_singleton_dims(operand, names, other_shape, other_names):
return lax_internal._convert_element_type(operands[0], preferred_element_type, output_weak_type)


@util.implements(np.inner, lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def inner(
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None,
preferred_element_type: DType | None = None,
) -> Array:
"""Compute the inner product of two arrays.
JAX implementation of :func:`numpy.inner`.
Unlike :func:`jax.numpy.matmul` or :func:`jax.numpy.dot`, this always performs
a contraction along the last dimension of each input.
Args:
a: array of shape ``(..., N)``
b: array of shape ``(..., N)``
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
such values indicating precision of ``a`` and ``b``.
preferred_element_type: either ``None`` (default), which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
array of shape ``(*a.shape[:-1], *b.shape[:-1])`` containing the batched vector
product of the inputs.
See also:
- :func:`jax.numpy.vecdot`: conjugate multiplication along a specified axis.
- :func:`jax.numpy.tensordot`: general tensor multiplication.
- :func:`jax.numpy.matmul`: general batched matrix & vector multiplication.
Examples:
For 1D inputs, this implements standard (non-conjugate) vector multiplication:
>>> a = jnp.array([1j, 3j, 4j])
>>> b = jnp.array([4., 2., 5.])
>>> jnp.inner(a, b)
Array(0.+30.j, dtype=complex64)
For multi-dimensional inputs, batch dimensions are stacked rather than broadcast:
>>> a = jnp.ones((2, 3))
>>> b = jnp.ones((5, 3))
>>> jnp.inner(a, b).shape
(2, 5)
"""
util.check_arraylike("inner", a, b)
if ndim(a) == 0 or ndim(b) == 0:
a = asarray(a, dtype=preferred_element_type)
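
The branch above handles zero-dimensional operands: when either input is a scalar, ``inner`` reduces to an element-wise product, mirroring :func:`numpy.inner`. A small sketch of this behavior:

>>> jnp.inner(2, jnp.array([1, 2, 3]))
Array([2, 4, 6], dtype=int32)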
