Skip to content

Commit

Permalink
Test that jax.numpy docstrings include examples
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 21, 2024
1 parent d63afd8 commit 2b46795
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 49 deletions.
12 changes: 6 additions & 6 deletions jax/_src/numpy/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array:
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
coefficients.
Example:
Examples:
Scalar inputs:
Expand Down Expand Up @@ -407,7 +407,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
coefficients.
Example:
Examples:
>>> p = jnp.array([2, 5, 1])
>>> jnp.polyval(p, 3)
Array(34., dtype=float32)
Expand Down Expand Up @@ -455,7 +455,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
division.
Example:
Examples:
>>> x1 = jnp.array([2, 3])
>>> x2 = jnp.array([5, 4, 1])
>>> jnp.polyadd(x1, x2)
Expand Down Expand Up @@ -637,7 +637,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
division.
Example:
Examples:
>>> x1 = np.array([2, 1, 0])
>>> x2 = np.array([0, 5, 0, 3])
>>> np.polymul(x1, x2)
Expand Down Expand Up @@ -702,7 +702,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
- :func:`jax.numpy.polysub`: Computes the difference of two polynomials.
- :func:`jax.numpy.polymul`: Computes the product of two polynomials.
Example:
Examples:
>>> x1 = jnp.array([5, 7, 9])
>>> x2 = jnp.array([4, 1])
>>> np.polydiv(x1, x2)
Expand Down Expand Up @@ -755,7 +755,7 @@ def polysub(a1: ArrayLike, a2: ArrayLike) -> Array:
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
division.
Example:
Examples:
>>> x1 = jnp.array([2, 3])
>>> x2 = jnp.array([5, 4, 1])
>>> jnp.polysub(x1, x2)
Expand Down
87 changes: 86 additions & 1 deletion jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,12 +652,26 @@ def _add(x: ArrayLike, y: ArrayLike, /) -> Array:
JAX implementation of :obj:`numpy.add`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
This function provides the implementation of the ``+`` operator for
JAX arrays.
Args:
x, y: arrays to add. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise addition.
Examples:
Calling ``add`` explicitly:
>>> x = jnp.arange(4)
>>> jnp.add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
Calling ``add`` via the ``+`` operator:
>>> x + 10
Array([10, 11, 12, 13], dtype=int32)
"""
x, y = promote_args("add", x, y)
return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y)
Expand All @@ -668,12 +682,26 @@ def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
JAX implementation of :obj:`numpy.multiply`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
This function provides the implementation of the ``*`` operator for
JAX arrays.
Args:
x, y: arrays to multiply. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise multiplication.
Examples:
Calling ``multiply`` explicitly:
>>> x = jnp.arange(4)
>>> jnp.multiply(x, 10)
Array([ 0, 10, 20, 30], dtype=int32)
Calling ``multiply`` via the ``*`` operator:
>>> x * 10
Array([ 0, 10, 20, 30], dtype=int32)
"""
x, y = promote_args("multiply", x, y)
return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
Expand All @@ -684,12 +712,26 @@ def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
This function provides the implementation of the ``&`` operator for
JAX arrays.
Args:
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise bitwise AND.
Examples:
Calling ``bitwise_and`` explicitly:
>>> x = jnp.arange(4)
>>> jnp.bitwise_and(x, 1)
Array([0, 1, 0, 1], dtype=int32)
Calling ``bitwise_and`` via the ``&`` operator:
>>> x & 1
Array([0, 1, 0, 1], dtype=int32)
"""
return lax.bitwise_and(*promote_args("bitwise_and", x, y))

Expand All @@ -699,12 +741,26 @@ def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
This function provides the implementation of the ``|`` operator for
JAX arrays.
Args:
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise bitwise OR.
Examples:
Calling ``bitwise_or`` explicitly:
>>> x = jnp.arange(4)
>>> jnp.bitwise_or(x, 1)
Array([1, 1, 3, 3], dtype=int32)
Calling ``bitwise_and`` via the ``&`` operator:
>>> x | 1
Array([1, 1, 3, 3], dtype=int32)
"""
return lax.bitwise_or(*promote_args("bitwise_or", x, y))

Expand All @@ -714,12 +770,26 @@ def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
This function provides the implementation of the ``^`` operator for
JAX arrays.
Args:
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise bitwise XOR.
Examples:
Calling ``bitwise_xor`` explicitly:
>>> x = jnp.arange(4)
>>> jnp.bitwise_xor(x, 1)
Array([1, 0, 3, 2], dtype=int32)
Calling ``bitwise_xor`` via the ``^`` operator:
>>> x ^ 1
Array([1, 0, 3, 2], dtype=int32)
"""
return lax.bitwise_xor(*promote_args("bitwise_xor", x, y))

Expand Down Expand Up @@ -958,6 +1028,11 @@ def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
Returns:
Array containing the result of the element-wise logical AND.
Examples:
>>> x = jnp.arange(4)
>>> jnp.logical_and(x, 1)
Array([False, True, True, True], dtype=bool)
"""
return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y)))

Expand All @@ -973,6 +1048,11 @@ def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
Returns:
Array containing the result of the element-wise logical OR.
Examples:
>>> x = jnp.arange(4)
>>> jnp.logical_or(x, 1)
Array([ True, True, True, True], dtype=bool)
"""
return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y)))

Expand All @@ -988,6 +1068,11 @@ def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
Returns:
Array containing the result of the element-wise logical XOR.
Examples:
>>> x = jnp.arange(4)
>>> jnp.logical_xor(x, 1)
Array([ True, False, False, False], dtype=bool)
"""
return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y)))

Expand Down Expand Up @@ -1373,7 +1458,7 @@ def rint(x: ArrayLike, /) -> Array:
If an element of x is exactly half way, e.g. ``0.5`` or ``1.5``, rint will round
to the nearest even integer.
Example:
Examples:
>>> x1 = jnp.array([5, 4, 7])
>>> jnp.rint(x1)
Array([5., 4., 7.], dtype=float32)
Expand Down
85 changes: 43 additions & 42 deletions jax/_src/numpy/vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,48 +215,49 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
Returns:
Vectorized version of the given function.
Here are a few examples of how one could write vectorized linear algebra
routines using :func:`vectorize`:
>>> from functools import partial
>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
... assert a.shape == b.shape and a.ndim == b.ndim == 1
... return jnp.array([a[1] * b[2] - a[2] * b[1],
... a[2] * b[0] - a[0] * b[2],
... a[0] * b[1] - a[1] * b[0]])
>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
... def matrix_vector_product(matrix, vector):
... assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
... return matrix @ vector
These functions are only written to handle 1D or 2D arrays (the ``assert``
statements will never be violated), but with vectorize they support
arbitrary dimensional inputs with NumPy style broadcasting, e.g.,
>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: input with shape (3,) does not have enough dimensions for all
core dimensions ('n', 'k') on vectorized function with excluded=frozenset()
and signature='(n,k),(k)->(k)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)
Note that this has different semantics than `jnp.matmul`:
>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3))) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].
Examples:
Here are a few examples of how one could write vectorized linear algebra
routines using :func:`vectorize`:
>>> from functools import partial
>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
... assert a.shape == b.shape and a.ndim == b.ndim == 1
... return jnp.array([a[1] * b[2] - a[2] * b[1],
... a[2] * b[0] - a[0] * b[2],
... a[0] * b[1] - a[1] * b[0]])
>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
... def matrix_vector_product(matrix, vector):
... assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
... return matrix @ vector
These functions are only written to handle 1D or 2D arrays (the ``assert``
statements will never be violated), but with vectorize they support
arbitrary dimensional inputs with NumPy style broadcasting, e.g.,
>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: input with shape (3,) does not have enough dimensions for all
core dimensions ('n', 'k') on vectorized function with excluded=frozenset()
and signature='(n,k),(k)->(k)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)
Note that this has different semantics than `jnp.matmul`:
>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3))) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].
"""
if any(not isinstance(exclude, (str, int)) for exclude in excluded):
raise TypeError("jax.numpy.vectorize can only exclude integer or string arguments, "
Expand Down
2 changes: 2 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6341,6 +6341,8 @@ def test_lax_numpy_docstrings(self):
self.assertNotEmpty(doc)
self.assertIn("Args:", doc, msg=f"'Args:' not found in docstring of jnp.{name}")
self.assertIn("Returns:", doc, msg=f"'Returns:' not found in docstring of jnp.{name}")
if name not in ["frompyfunc", "isdtype", "promote_types"]:
self.assertIn("Examples:", doc, msg=f"'Examples:' not found in docstring of jnp.{name}")

@parameterized.named_parameters(
{"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False])
Expand Down

0 comments on commit 2b46795

Please sign in to comment.