Skip to content

Commit

Permalink
frombuffer_docstring_added
Browse files Browse the repository at this point in the history
description_changed_examp_added

doc_byte_fixed
  • Loading branch information
selamw1 committed Sep 13, 2024
1 parent 0daca46 commit 8db394e
Showing 1 changed file with 39 additions and 1 deletion.
40 changes: 39 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5194,9 +5194,47 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:

# General np.from* style functions mostly delegate to numpy.

@util.implements(np.frombuffer)
def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float,
count: int = -1, offset: int = 0) -> Array:
r"""Convert a buffer into a 1-D JAX array.
JAX implementation of :func:`numpy.frombuffer`.
Args:
buffer: an object containing the data. It must be either a bytes object
with a length that is an integer multiple of the dtype element size, or
it must be an object exporting the Python buffer interface.
dtype: optional. Desired data type for the array. Default is ``float64``.
Since we delegate directly to NumPy, but the result may be down-cast
to ``float32`` depending on the state of the X64 flag.
count: optional integer specifying the number of items to read from the buffer.
If -1 (default), all items from the buffer are read.
offset: optional integer specifying the number of bytes to skip at the beginning
of the buffer. Default is 0.
Returns:
A 1-D JAX array representing the interpreted data from the buffer.
See also:
- :func:`jax.numpy.fromstring`: convert a string of text into 1-D JAX array.
Examples:
Using a bytes buffer:
>>> buf = b"\x00\x01\x02\x03\x04"
>>> jnp.frombuffer(buf, dtype=jnp.uint8)
Array([0, 1, 2, 3, 4], dtype=uint8)
>>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1)
Array([1, 2, 3, 4], dtype=uint8)
Constructing a JAX array via the Python buffer interface, using Python's
built-in :mod:`array` module.
>>> from array import array
>>> pybuffer = array('i', [0, 1, 2, 3, 4])
>>> jnp.frombuffer(pybuffer, dtype=jnp.int32)
Array([0, 1, 2, 3, 4], dtype=int32)
"""
return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset))


Expand Down

0 comments on commit 8db394e

Please sign in to comment.