diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index ca4e3ebaf6a2..cce8bb8e6f7f 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -317,7 +317,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array: - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given coefficients. - Example: + Examples: Scalar inputs: @@ -407,7 +407,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given coefficients. - Example: + Examples: >>> p = jnp.array([2, 5, 1]) >>> jnp.polyval(p, 3) Array(34., dtype=float32) @@ -455,7 +455,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: - :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial division. - Example: + Examples: >>> x1 = jnp.array([2, 3]) >>> x2 = jnp.array([5, 4, 1]) >>> jnp.polyadd(x1, x2) @@ -637,7 +637,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - - :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial division. - Example: + Examples: >>> x1 = np.array([2, 1, 0]) >>> x2 = np.array([0, 5, 0, 3]) >>> np.polymul(x1, x2) @@ -702,7 +702,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> - :func:`jax.numpy.polysub`: Computes the difference of two polynomials. - :func:`jax.numpy.polymul`: Computes the product of two polynomials. - Example: + Examples: >>> x1 = jnp.array([5, 7, 9]) >>> x2 = jnp.array([4, 1]) >>> np.polydiv(x1, x2) @@ -755,7 +755,7 @@ def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: - :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial division. - Example: + Examples: >>> x1 = jnp.array([2, 3]) >>> x2 = jnp.array([5, 4, 1]) >>> jnp.polysub(x1, x2) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 857ed8668d59..f4c174f01afb 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -652,12 +652,26 @@ def _add(x: ArrayLike, y: ArrayLike, /) -> Array: JAX implementation of :obj:`numpy.add`. This is a universal function, and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``+`` operator for + JAX arrays. Args: x, y: arrays to add. Must be broadcastable to a common shape. Returns: Array containing the result of the element-wise addition. + + Examples: + Calling ``add`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.add(x, 10) + Array([10, 11, 12, 13], dtype=int32) + + Calling ``add`` via the ``+`` operator: + + >>> x + 10 + Array([10, 11, 12, 13], dtype=int32) """ x, y = promote_args("add", x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) @@ -668,12 +682,26 @@ def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: JAX implementation of :obj:`numpy.multiply`. This is a universal function, and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``*`` operator for + JAX arrays. Args: x, y: arrays to multiply. Must be broadcastable to a common shape. Returns: Array containing the result of the element-wise multiplication. + + Examples: + Calling ``multiply`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.multiply(x, 10) + Array([ 0, 10, 20, 30], dtype=int32) + + Calling ``multiply`` via the ``*`` operator: + + >>> x * 10 + Array([ 0, 10, 20, 30], dtype=int32) """ x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) @@ -684,12 +712,26 @@ def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function, and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``&`` operator for + JAX arrays. Args: x, y: integer or boolean arrays. Must be broadcastable to a common shape. Returns: Array containing the result of the element-wise bitwise AND. + + Examples: + Calling ``bitwise_and`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.bitwise_and(x, 1) + Array([0, 1, 0, 1], dtype=int32) + + Calling ``bitwise_and`` via the ``&`` operator: + + >>> x & 1 + Array([0, 1, 0, 1], dtype=int32) """ return lax.bitwise_and(*promote_args("bitwise_and", x, y)) @@ -699,12 +741,26 @@ def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function, and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``|`` operator for + JAX arrays. Args: x, y: integer or boolean arrays. Must be broadcastable to a common shape. Returns: Array containing the result of the element-wise bitwise OR. + + Examples: + Calling ``bitwise_or`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.bitwise_or(x, 1) + Array([1, 1, 3, 3], dtype=int32) + + Calling ``bitwise_and`` via the ``&`` operator: + + >>> x | 1 + Array([1, 1, 3, 3], dtype=int32) """ return lax.bitwise_or(*promote_args("bitwise_or", x, y)) @@ -714,12 +770,26 @@ def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function, and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``^`` operator for + JAX arrays. Args: x, y: integer or boolean arrays. Must be broadcastable to a common shape. Returns: Array containing the result of the element-wise bitwise XOR. + + Examples: + Calling ``bitwise_xor`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.bitwise_xor(x, 1) + Array([1, 0, 3, 2], dtype=int32) + + Calling ``bitwise_xor`` via the ``^`` operator: + + >>> x ^ 1 + Array([1, 0, 3, 2], dtype=int32) """ return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) @@ -958,6 +1028,11 @@ def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: Returns: Array containing the result of the element-wise logical AND. + + Examples: + >>> x = jnp.arange(4) + >>> jnp.logical_and(x, 1) + Array([False, True, True, True], dtype=bool) """ return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) @@ -973,6 +1048,11 @@ def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: Returns: Array containing the result of the element-wise logical OR. + + Examples: + >>> x = jnp.arange(4) + >>> jnp.logical_or(x, 1) + Array([ True, True, True, True], dtype=bool) """ return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) @@ -988,6 +1068,11 @@ def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: Returns: Array containing the result of the element-wise logical XOR. + + Examples: + >>> x = jnp.arange(4) + >>> jnp.logical_xor(x, 1) + Array([ True, False, False, False], dtype=bool) """ return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) @@ -1373,7 +1458,7 @@ def rint(x: ArrayLike, /) -> Array: If an element of x is exactly half way, e.g. ``0.5`` or ``1.5``, rint will round to the nearest even integer. - Example: + Examples: >>> x1 = jnp.array([5, 4, 7]) >>> jnp.rint(x1) Array([5., 4., 7.], dtype=float32) diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index dc368367e14e..e7a0e2142327 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -215,48 +215,49 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None): Returns: Vectorized version of the given function. - Here are a few examples of how one could write vectorized linear algebra - routines using :func:`vectorize`: - - >>> from functools import partial - - >>> @partial(jnp.vectorize, signature='(k),(k)->(k)') - ... def cross_product(a, b): - ... assert a.shape == b.shape and a.ndim == b.ndim == 1 - ... return jnp.array([a[1] * b[2] - a[2] * b[1], - ... a[2] * b[0] - a[0] * b[2], - ... a[0] * b[1] - a[1] * b[0]]) - - >>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)') - ... def matrix_vector_product(matrix, vector): - ... assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape - ... return matrix @ vector - - These functions are only written to handle 1D or 2D arrays (the ``assert`` - statements will never be violated), but with vectorize they support - arbitrary dimensional inputs with NumPy style broadcasting, e.g., - - >>> cross_product(jnp.ones(3), jnp.ones(3)).shape - (3,) - >>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape - (2, 3) - >>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape - (2, 2, 3) - >>> matrix_vector_product(jnp.ones(3), jnp.ones(3)) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ValueError: input with shape (3,) does not have enough dimensions for all - core dimensions ('n', 'k') on vectorized function with excluded=frozenset() - and signature='(n,k),(k)->(k)' - >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape - (2,) - >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape - (4, 2) - - Note that this has different semantics than `jnp.matmul`: - - >>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3))) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4]. + Examples: + Here are a few examples of how one could write vectorized linear algebra + routines using :func:`vectorize`: + + >>> from functools import partial + + >>> @partial(jnp.vectorize, signature='(k),(k)->(k)') + ... def cross_product(a, b): + ... assert a.shape == b.shape and a.ndim == b.ndim == 1 + ... return jnp.array([a[1] * b[2] - a[2] * b[1], + ... a[2] * b[0] - a[0] * b[2], + ... a[0] * b[1] - a[1] * b[0]]) + + >>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)') + ... def matrix_vector_product(matrix, vector): + ... assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape + ... return matrix @ vector + + These functions are only written to handle 1D or 2D arrays (the ``assert`` + statements will never be violated), but with vectorize they support + arbitrary dimensional inputs with NumPy style broadcasting, e.g., + + >>> cross_product(jnp.ones(3), jnp.ones(3)).shape + (3,) + >>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape + (2, 3) + >>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape + (2, 2, 3) + >>> matrix_vector_product(jnp.ones(3), jnp.ones(3)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ValueError: input with shape (3,) does not have enough dimensions for all + core dimensions ('n', 'k') on vectorized function with excluded=frozenset() + and signature='(n,k),(k)->(k)' + >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape + (2,) + >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape + (4, 2) + + Note that this has different semantics than `jnp.matmul`: + + >>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3))) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4]. """ if any(not isinstance(exclude, (str, int)) for exclude in excluded): raise TypeError("jax.numpy.vectorize can only exclude integer or string arguments, " diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a1d1a0292338..9dc2e079bb3f 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6341,6 +6341,8 @@ def test_lax_numpy_docstrings(self): self.assertNotEmpty(doc) self.assertIn("Args:", doc, msg=f"'Args:' not found in docstring of jnp.{name}") self.assertIn("Returns:", doc, msg=f"'Returns:' not found in docstring of jnp.{name}") + if name not in ["frompyfunc", "isdtype", "promote_types"]: + self.assertIn("Examples:", doc, msg=f"'Examples:' not found in docstring of jnp.{name}") @parameterized.named_parameters( {"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False])