Skip to content

Commit

Permalink
jnp.delete: better docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 3, 2024
1 parent e70191b commit dbeec8c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 11 deletions.
72 changes: 61 additions & 11 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3581,24 +3581,74 @@ def append(
return concatenate([arr, values], axis=axis)


@util.implements(np.delete,
lax_description=_dedent("""
delete() usually requires the index specification to be static. If the index
is an integer array that is guaranteed to contain unique entries, you may
specify ``assume_unique_indices=True`` to perform the operation in a
manner that does not require static indices."""),
extra_params=_dedent("""
assume_unique_indices : int, optional (default=False)
In case of array-like integer (not boolean) indices, assume the indices are unique,
and perform the deletion in a way that is compatible with JIT and other JAX
transformations."""))
def delete(
arr: ArrayLike,
obj: ArrayLike | slice,
axis: int | None = None,
*,
assume_unique_indices: bool = False,
) -> Array:
"""Delete entry or entries from an array.
JAX implementation of :func:`numpy.delete`.
Args:
arr: array from which entries will be deleted.
obj: index, indices, or slice to be deleted.
axis: axis along which entries will be deleted.
assume_unique_indices: In case of array-like integer (not boolean) indices,
assume the indices are unique, and perform the deletion in a way that is
compatible with JIT and other JAX transformations.
Returns:
Copy of ``arr`` with specified indices deleted.
Note:
``delete()`` usually requires the index specification to be static. If the
index is an integer array that is guaranteed to contain unique entries, you
may specify ``assume_unique_indices=True`` to perform the operation in a
manner that does not require static indices.
Examples:
Delete entries from a 1D array:
>>> a = jnp.array([4, 5, 6, 7, 8, 9])
>>> jnp.delete(a, 2)
Array([4, 5, 7, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(1, 4)) # delete a[1:4]
Array([4, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(None, None, 2)) # delete a[::2]
Array([5, 7, 9], dtype=int32)
Delete entries from a 2D array along a specified axis:
>>> a2 = jnp.array([[4, 5, 6],
... [7, 8, 9]])
>>> jnp.delete(a2, 1, axis=1)
Array([[4, 6],
[7, 9]], dtype=int32)
Delete multiple entries via a sequence of indices:
>>> indices = jnp.array([0, 1, 3])
>>> jnp.delete(a, indices)
Array([6, 8, 9], dtype=int32)
This will fail under :func:`~jax.jit` and other transformations, because
the output shape cannot be known with the possibility of duplicate indices:
>>> jax.jit(jnp.delete)(a, indices) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3].
If you can ensure that the indices are unique, pass ``assume_unique_indices``
to allow this to be executed under JIT:
>>> jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices'])
>>> jit_delete(a, indices, assume_unique_indices=True)
Array([6, 8, 9], dtype=int32)
"""
util.check_arraylike("delete", arr)
if axis is None:
arr = ravel(arr)
Expand Down
1 change: 1 addition & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6020,6 +6020,7 @@ def test_lax_numpy_docstrings(self):
# Functions that have their own docstrings & don't wrap numpy.
known_exceptions = {
'argwhere',
'delete',
'flatnonzero',
'fromfile',
'fromiter',
Expand Down

0 comments on commit dbeec8c

Please sign in to comment.