Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DOC: Improve docs for jax.numpy: float_power and nextafter #23536

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,14 +872,77 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array:
"""
return lax.max(*promote_args("maximum", x, y))

@implements(np.float_power, module='numpy')

@partial(jit, inline=True)
def float_power(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Calculate element-wise base ``x`` exponential of ``y``.

JAX implementation of :obj:`numpy.float_power`.

Args:
x: scalar or array. Specifies the bases.
y: scalar or array. Specifies the exponents. ``x`` and ``y`` should either
have same shape or be broadcast compatible.

Returns:
An array containing the base ``x`` exponentials of ``y``, promoting to the
inexact dtype.

See also:
- :func:`jax.numpy.exp`: Calculates element-wise exponential of the input.
- :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of
the input.

Examples:
Inputs with same shape:

>>> x = jnp.array([3, 1, -5])
>>> y = jnp.array([2, 4, -1])
>>> jnp.float_power(x, y)
Array([ 9. , 1. , -0.2], dtype=float32)

Inputs with broacast compatibility:

>>> x1 = jnp.array([[2, -4, 1],
... [-1, 2, 3]])
>>> y1 = jnp.array([-2, 1, 4])
>>> jnp.float_power(x1, y1)
Array([[ 0.25, -4. , 1. ],
[ 1. , 2. , 81. ]], dtype=float32)

``jnp.float_power`` produces ``nan`` for negative values raised to a non-integer
values.

>>> jnp.float_power(-3, 1.7)
Array(nan, dtype=float32, weak_type=True)
"""
return lax.pow(*promote_args_inexact("float_power", x, y))

@implements(np.nextafter, module='numpy')

@partial(jit, inline=True)
def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise next floating point value after ``x`` towards ``y``.

JAX implementation of :obj:`numpy.nextafter`.

Args:
x: scalar or array. Specifies the value after which the next number is found.
y: scalar or array. Specifies the direction towards which the next number is
found. ``x`` and ``y`` should either have same shape or be broadcast
compatible.

Returns:
An array containing the next representable number of ``x`` in the direction
of ``y``.

Examples:
>>> jnp.nextafter(2, 1) # doctest: +SKIP
Array(1.9999999, dtype=float32, weak_type=True)
>>> x = jnp.array([3, -2, 1])
>>> y = jnp.array([2, -1, 2])
>>> jnp.nextafter(x, y) # doctest: +SKIP
Array([ 2.9999998, -1.9999999, 1.0000001], dtype=float32)
"""
return lax.nextafter(*promote_args_inexact("nextafter", x, y))

# Logical ops
Expand Down