Skip to content

Commit

Permalink
jnp.searchsorted: support sorter argument
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 29, 2024
1 parent cc0a20f commit ef40476
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
13 changes: 8 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7168,22 +7168,25 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt
'sort' is often more performant on accelerator backends like GPU and TPU
(particularly when ``v`` is very large), and 'compare_all' can be most performant
when ``a`` is very small."""))
@partial(jit, static_argnames=('side', 'sorter', 'method'))
@partial(jit, static_argnames=('side', 'method'))
def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
sorter: None = None, *, method: str = 'scan') -> Array:
util.check_arraylike("searchsorted", a, v)
sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array:
if sorter is None:
util.check_arraylike("searchsorted", a, v)
else:
util.check_arraylike("searchsorted", a, v, sorter)
if side not in ['left', 'right']:
raise ValueError(f"{side!r} is an invalid value for keyword 'side'. "
"Expected one of ['left', 'right'].")
if method not in ['scan', 'scan_unrolled', 'sort', 'compare_all']:
raise ValueError(
f"{method!r} is an invalid value for keyword 'method'. "
"Expected one of ['sort', 'scan', 'scan_unrolled', 'compare_all'].")
if sorter is not None:
raise NotImplementedError("sorter is not implemented")
if ndim(a) != 1:
raise ValueError("a should be 1-dimensional")
a, v = util.promote_dtypes(a, v)
if sorter is not None:
a = a[sorter]
dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64
if len(a) == 0:
return zeros_like(v, dtype=dtype)
Expand Down
2 changes: 1 addition & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ s_ = _np.s_
save = _np.save
savez = _np.savez
def searchsorted(a: ArrayLike, v: ArrayLike, side: str = ...,
sorter: None = ..., *, method: str = ...) -> Array: ...
sorter: ArrayLike | None = ..., *, method: str = ...) -> Array: ...
def select(
condlist: Sequence[ArrayLike],
choicelist: Sequence[ArrayLike],
Expand Down
15 changes: 10 additions & 5 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2660,13 +2660,18 @@ def np_fun(arg):
side=['left', 'right'],
dtype=number_dtypes,
method=['sort', 'scan', 'scan_unrolled', 'compare_all'],
use_sorter=[True, False],
)
def testSearchsorted(self, ashape, vshape, side, dtype, method):
def testSearchsorted(self, ashape, vshape, side, dtype, method, use_sorter):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [np.sort(rng(ashape, dtype)), rng(vshape, dtype)]
def np_fun(a, v):
return np.searchsorted(a, v, side=side).astype('int32')
jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side, method=method)
def args_maker():
a = rng(ashape, dtype)
v = rng(vshape, dtype)
return (a, v, np.argsort(a)) if use_sorter else (np.sort(a), v)
def np_fun(a, v, sorter=None):
return np.searchsorted(a, v, side=side, sorter=sorter).astype('int32')
def jnp_fun(a, v, sorter=None):
return jnp.searchsorted(a, v, side=side, method=method, sorter=sorter)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

Expand Down

0 comments on commit ef40476

Please sign in to comment.