From 8db394e5ef9421454fbddfbae076124516054a15 Mon Sep 17 00:00:00 2001 From: selamw1 Date: Tue, 10 Sep 2024 16:16:02 -0700 Subject: [PATCH] frombuffer_docstring_added description_changed_examp_added doc_byte_fixed --- jax/_src/numpy/lax_numpy.py | 40 ++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6c6cacef673e..211e56bb451a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5194,9 +5194,47 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: # General np.from* style functions mostly delegate to numpy. -@util.implements(np.frombuffer) def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, count: int = -1, offset: int = 0) -> Array: + r"""Convert a buffer into a 1-D JAX array. + + JAX implementation of :func:`numpy.frombuffer`. + + Args: + buffer: an object containing the data. It must be either a bytes object + with a length that is an integer multiple of the dtype element size, or + it must be an object exporting the Python buffer interface. + dtype: optional. Desired data type for the array. Default is ``float64``. + Since we delegate directly to NumPy, but the result may be down-cast + to ``float32`` depending on the state of the X64 flag. + count: optional integer specifying the number of items to read from the buffer. + If -1 (default), all items from the buffer are read. + offset: optional integer specifying the number of bytes to skip at the beginning + of the buffer. Default is 0. + + Returns: + A 1-D JAX array representing the interpreted data from the buffer. + + See also: + - :func:`jax.numpy.fromstring`: convert a string of text into 1-D JAX array. + + Examples: + Using a bytes buffer: + + >>> buf = b"\x00\x01\x02\x03\x04" + >>> jnp.frombuffer(buf, dtype=jnp.uint8) + Array([0, 1, 2, 3, 4], dtype=uint8) + >>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1) + Array([1, 2, 3, 4], dtype=uint8) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. + + >>> from array import array + >>> pybuffer = array('i', [0, 1, 2, 3, 4]) + >>> jnp.frombuffer(pybuffer, dtype=jnp.int32) + Array([0, 1, 2, 3, 4], dtype=int32) + """ return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset))