Skip to content

Commit

Permalink
jnp.searchsorted: support sorter argument
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 28, 2024
1 parent db11842 commit 0da5eff
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
11 changes: 7 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7037,22 +7037,25 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt
'sort' is often more performant on accelerator backends like GPU and TPU
(particularly when ``v`` is very large), and 'compare_all' can be most performant
when ``a`` is very small."""))
@partial(jit, static_argnames=('side', 'sorter', 'method'))
@partial(jit, static_argnames=('side', 'method'))
def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
sorter: None = None, *, method: str = 'scan') -> Array:
util.check_arraylike("searchsorted", a, v)
if sorter is None:
util.check_arraylike("searchsorted", a, v)
else:
util.check_arraylike("searchsorted", a, v, sorter)
if side not in ['left', 'right']:
raise ValueError(f"{side!r} is an invalid value for keyword 'side'. "
"Expected one of ['left', 'right'].")
if method not in ['scan', 'scan_unrolled', 'sort', 'compare_all']:
raise ValueError(
f"{method!r} is an invalid value for keyword 'method'. "
"Expected one of ['sort', 'scan', 'scan_unrolled', 'compare_all'].")
if sorter is not None:
raise NotImplementedError("sorter is not implemented")
if ndim(a) != 1:
raise ValueError("a should be 1-dimensional")
a, v = util.promote_dtypes(a, v)
if sorter is not None:
a = a[sorter]
dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64
if len(a) == 0:
return zeros_like(v, dtype=dtype)
Expand Down
20 changes: 15 additions & 5 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2660,13 +2660,23 @@ def np_fun(arg):
side=['left', 'right'],
dtype=number_dtypes,
method=['sort', 'scan', 'scan_unrolled', 'compare_all'],
use_sorter=[True, False],
)
def testSearchsorted(self, ashape, vshape, side, dtype, method):
def testSearchsorted(self, ashape, vshape, side, dtype, method, use_sorter):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [np.sort(rng(ashape, dtype)), rng(vshape, dtype)]
def np_fun(a, v):
return np.searchsorted(a, v, side=side).astype('int32')
jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side, method=method)
def args_maker():
a = rng(ashape, dtype)
v = rng(vshape, dtype)
if use_sorter:
i = np.argsort(a)
return a, v, i
else:
a.sort()
return a, v
def np_fun(a, v, sorter=None):
return np.searchsorted(a, v, side=side, sorter=sorter).astype('int32')
def jnp_fun(a, v, sorter=None):
return jnp.searchsorted(a, v, side=side, method=method, sorter=sorter)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

Expand Down

0 comments on commit 0da5eff

Please sign in to comment.