Skip to content

Commit

Permalink
Merge pull request #23536 from rajasekharporeddy:testbranch1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673011373
  • Loading branch information
jax authors committed Sep 10, 2024
2 parents a8b68c2 + ee04646 commit 6037dba
Showing 1 changed file with 65 additions and 2 deletions.
67 changes: 65 additions & 2 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,14 +872,77 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array:
"""
return lax.max(*promote_args("maximum", x, y))

@implements(np.float_power, module='numpy')

@partial(jit, inline=True)
def float_power(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Calculate element-wise base ``x`` exponential of ``y``.
JAX implementation of :obj:`numpy.float_power`.
Args:
x: scalar or array. Specifies the bases.
y: scalar or array. Specifies the exponents. ``x`` and ``y`` should either
have same shape or be broadcast compatible.
Returns:
An array containing the base ``x`` exponentials of ``y``, promoting to the
inexact dtype.
See also:
- :func:`jax.numpy.exp`: Calculates element-wise exponential of the input.
- :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of
the input.
Examples:
Inputs with same shape:
>>> x = jnp.array([3, 1, -5])
>>> y = jnp.array([2, 4, -1])
>>> jnp.float_power(x, y)
Array([ 9. , 1. , -0.2], dtype=float32)
Inputs with broacast compatibility:
>>> x1 = jnp.array([[2, -4, 1],
... [-1, 2, 3]])
>>> y1 = jnp.array([-2, 1, 4])
>>> jnp.float_power(x1, y1)
Array([[ 0.25, -4. , 1. ],
[ 1. , 2. , 81. ]], dtype=float32)
``jnp.float_power`` produces ``nan`` for negative values raised to a non-integer
values.
>>> jnp.float_power(-3, 1.7)
Array(nan, dtype=float32, weak_type=True)
"""
return lax.pow(*promote_args_inexact("float_power", x, y))

@implements(np.nextafter, module='numpy')

@partial(jit, inline=True)
def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise next floating point value after ``x`` towards ``y``.
JAX implementation of :obj:`numpy.nextafter`.
Args:
x: scalar or array. Specifies the value after which the next number is found.
y: scalar or array. Specifies the direction towards which the next number is
found. ``x`` and ``y`` should either have same shape or be broadcast
compatible.
Returns:
An array containing the next representable number of ``x`` in the direction
of ``y``.
Examples:
>>> jnp.nextafter(2, 1) # doctest: +SKIP
Array(1.9999999, dtype=float32, weak_type=True)
>>> x = jnp.array([3, -2, 1])
>>> y = jnp.array([2, -1, 2])
>>> jnp.nextafter(x, y) # doctest: +SKIP
Array([ 2.9999998, -1.9999999, 1.0000001], dtype=float32)
"""
return lax.nextafter(*promote_args_inexact("nextafter", x, y))

# Logical ops
Expand Down

0 comments on commit 6037dba

Please sign in to comment.