From c5bc2412a74837634e2931eeb73b31cba4575cee Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Tue, 10 Sep 2024 20:56:49 +0530 Subject: [PATCH] Improve doc for jnp.trim_zeros --- jax/_src/numpy/lax_numpy.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 9586ede8e127..820a05f56548 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6375,8 +6375,28 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: return res -@util.implements(np.trim_zeros) def trim_zeros(filt, trim='fb'): + """Trim leading and/or trailing zeros of the input array. + + JAX implementation of :func:`numpy.trim_zeros`. + + Args: + filt: input array. Must have ``filt.ndim == 1``. + trim: string, optional, default = ``fb``. Specifies from which end the input + is trimmed. + + - ``f`` - trims only the leading zeros. + - ``b`` - trims only the trailing zeros. + - ``fb`` - trims both leading and trailing zeros. + + Returns: + An array containig the trimmed input with same dtype as ``filt``. + + Examples: + >>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0]) + >>> jnp.trim_zeros(x) + Array([2, 0, 1, 4, 3], dtype=int32) + """ filt = core.concrete_or_error(asarray, filt, "Error arose in the `filt` argument of trim_zeros()") nz = (filt == 0)