diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6c6cacef673e..ef6a30400d30 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5194,9 +5194,50 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: # General np.from* style functions mostly delegate to numpy. -@util.implements(np.frombuffer) def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, count: int = -1, offset: int = 0) -> Array: + r"""Convert a buffer into a 1-D JAX array. + + JAX implementation of :func:`numpy.frombuffer`. + + Args: + buffer: an object containing the data. It must be either a bytes object with + a length that is an integer multiple of the dtype element size, or + it must be an object exporting the `Python buffer interface`_. + dtype: optional. Desired data type for the array. Default is ``float64``. + This specifes the dtype used to parse the buffer, but note that after parsing, + 64-bit values will be cast to 32-bit JAX arrays if the ``jax_enable_x64`` + flag is set to ``False``. + count: optional integer specifying the number of items to read from the buffer. + If -1 (default), all items from the buffer are read. + offset: optional integer specifying the number of bytes to skip at the beginning + of the buffer. Default is 0. + + Returns: + A 1-D JAX array representing the interpreted data from the buffer. + + See also: + - :func:`jax.numpy.fromstring`: convert a string of text into 1-D JAX array. + + Examples: + Using a bytes buffer: + + >>> buf = b"\x00\x01\x02\x03\x04" + >>> jnp.frombuffer(buf, dtype=jnp.uint8) + Array([0, 1, 2, 3, 4], dtype=uint8) + >>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1) + Array([1, 2, 3, 4], dtype=uint8) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. + + >>> from array import array + >>> pybuffer = array('i', [0, 1, 2, 3, 4]) + >>> jnp.frombuffer(pybuffer, dtype=jnp.int32) + Array([0, 1, 2, 3, 4], dtype=int32) + + .. _Python buffer interface: https://docs.python.org/3/c-api/buffer.html + """ return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset))