Skip to content

Commit

Permalink
Improve docs for jnp.stack & related APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 5, 2024
1 parent 97db78b commit 6bb4528
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 17 deletions.
209 changes: 193 additions & 16 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2537,30 +2537,207 @@ def _split(op: str, ary: ArrayLike,
return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end))
for start, end in zip(split_indices[:-1], split_indices[1:])]

@util.implements(np.split, lax_description=_ARRAY_VIEW_DOC)

def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike,
axis: int = 0) -> list[Array]:
"""Split an array into sub-arrays.
JAX implementation of :func:`numpy.split`.
Args:
ary: N-dimensional array-like object to split
indices_or_sections: either a single integer or a sequence of indices.
- if ``indices_or_sections`` is an integer *N*, then *N* must evenly divide
``ary.shape[axis]`` and ``ary`` will be divided into *N* equally-sized
chunks along ``axis``.
- if ``indices_or_sections`` is a sequence of integers, then these integers
specify the boundary between unevenly-sized chunks along ``axis``; see
examples below.
axis: the axis along which to split; defaults to 0.
Returns:
A list of arrays. If ``indices_or_sections`` is an integer *N*, then the list is
of length *N*. If ``indices_or_sections`` is a sequence *seq*, then the list is
is of length *len(seq) + 1*.
Examples:
Splitting a 1-dimensional array:
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
Split into three equal sections:
>>> chunks = jnp.split(x, 3)
>>> print(*chunks)
[1 2 3] [4 5 6] [7 8 9]
Split into sections by index:
>>> chunks = jnp.split(x, [2, 7]) # [x[0:2], x[2:7], x[7:]]
>>> print(*chunks)
[1 2] [3 4 5 6 7] [8 9]
Splitting a two-dimensional array along axis 1:
>>> x = jnp.array([[1, 2, 3, 4],
... [5, 6, 7, 8]])
>>> x1, x2 = jnp.split(x, 2, axis=1)
>>> print(x1)
[[1 2]
[5 6]]
>>> print(x2)
[[3 4]
[7 8]]
See also:
- :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections``
to be an integer that does not evenly divide the size of the array.
- :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0
- :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1
- :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2
"""
return _split("split", ary, indices_or_sections, axis=axis)

def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, int | ArrayLike], list[Array]]:
@util.implements(getattr(np, op), update_doc=False)
def f(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]:
# for 1-D array, hsplit becomes vsplit
nonlocal axis
util.check_arraylike(op, ary)
a = asarray(ary)
if axis == 1 and len(a.shape) == 1:
axis = 0
return _split(op, ary, indices_or_sections, axis=axis)
return f

vsplit = _split_on_axis("vsplit", axis=0)
hsplit = _split_on_axis("hsplit", axis=1)
dsplit = _split_on_axis("dsplit", axis=2)
def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]:
"""Split an array into sub-arrays vertically.
JAX implementation of :func:`numpy.vsplit`.
Refer to the documentation of :func:`jax.numpy.split` for details; ``vsplit`` is
equivalent to ``split`` with ``axis=0``.
Examples:
1D array:
>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> x1, x2 = jnp.vsplit(x, 2)
>>> print(x1, x2)
[1 2 3] [4 5 6]
2D array:
>>> x = jnp.array([[1, 2, 3, 4],
... [5, 6, 7, 8]])
>>> x1, x2 = jnp.vsplit(x, 2)
>>> print(x1, x2)
[[1 2 3 4]] [[5 6 7 8]]
See also:
- :func:`jax.numpy.split`: split an array along any axis.
- :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1
- :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2
- :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections``
to be an integer that does not evenly divide the size of the array.
"""
return _split("vsplit", ary, indices_or_sections, axis=0)


def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]:
"""Split an array into sub-arrays horizontally.
JAX implementation of :func:`numpy.hsplit`.
Refer to the documentation of :func:`jax.numpy.split` for details. ``hsplit`` is
equivalent to ``split`` with ``axis=1``, or ``axis=0`` for one-dimensional arrays.
Examples:
1D array:
>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> x1, x2 = jnp.hsplit(x, 2)
>>> print(x1, x2)
[1 2 3] [4 5 6]
2D array:
>>> x = jnp.array([[1, 2, 3, 4],
... [5, 6, 7, 8]])
>>> x1, x2 = jnp.hsplit(x, 2)
>>> print(x1)
[[1 2]
[5 6]]
>>> print(x2)
[[3 4]
[7 8]]
See also:
- :func:`jax.numpy.split`: split an array along any axis.
- :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0
- :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2
- :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections``
to be an integer that does not evenly divide the size of the array.
"""
util.check_arraylike("hsplit", ary)
a = asarray(ary)
return _split("hsplit", a, indices_or_sections, axis=0 if a.ndim == 1 else 1)


def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]:
"""Split an array into sub-arrays depth-wise.
JAX implementation of :func:`numpy.dsplit`.
Refer to the documentation of :func:`jax.numpy.split` for details. ``dsplit`` is
equivalent to ``split`` with ``axis=2``.
Examples:
>>> x = jnp.arange(12).reshape(3, 1, 4)
>>> print(x)
[[[ 0 1 2 3]]
<BLANKLINE>
[[ 4 5 6 7]]
<BLANKLINE>
[[ 8 9 10 11]]]
>>> x1, x2 = jnp.dsplit(x, 2)
>>> print(x1)
[[[0 1]]
<BLANKLINE>
[[4 5]]
<BLANKLINE>
[[8 9]]]
>>> print(x2)
[[[ 2 3]]
<BLANKLINE>
[[ 6 7]]
<BLANKLINE>
[[10 11]]]
See also:
- :func:`jax.numpy.split`: split an array along any axis.
- :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0
- :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1
- :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections``
to be an integer that does not evenly divide the size of the array.
"""
return _split("dsplit", ary, indices_or_sections, axis=2)


@util.implements(np.array_split)
def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike,
axis: int = 0) -> list[Array]:
"""Split an array into sub-arrays.
JAX implementation of :func:`numpy.array_split`.
Refer to the documentation of :func:`jax.numpy.split` for details; ``array_split``
is equivalent to ``split``, but allows integer ``indices_or_sections`` which does
not evenly divide the split axis.
Examples:
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> chunks = jnp.array_split(x, 4)
>>> print(*chunks)
[1 2 3] [4 5] [6 7] [8 9]
See also:
- :func:`jax.numpy.split`: split an array along any axis.
- :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0
- :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1
- :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2
"""
return _split("array_split", ary, indices_or_sections, axis=axis)


Expand Down
3 changes: 2 additions & 1 deletion tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6289,6 +6289,7 @@ def test_lax_numpy_docstrings(self):
unimplemented = ['fromfile', 'fromiter']
aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2',
'amax', 'amin', 'around', 'divide', 'round_']
skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split']

for name in dir(jnp):
if name.startswith('_') or name in unimplemented:
Expand All @@ -6313,7 +6314,7 @@ def test_lax_numpy_docstrings(self):
raise Exception(f"jnp.{name} does not have a wrapped docstring.")
elif name in aliases:
assert "Alias of" in obj.__doc__
else:
elif name not in skip_args_check:
# Other functions should have nontrivial docs including "Args" and "Returns".
doc = obj.__doc__
self.assertNotEmpty(doc)
Expand Down

0 comments on commit 6bb4528

Please sign in to comment.