diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 029f4f1d9942..3a937fa72efc 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -7037,10 +7037,13 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt 'sort' is often more performant on accelerator backends like GPU and TPU (particularly when ``v`` is very large), and 'compare_all' can be most performant when ``a`` is very small.""")) -@partial(jit, static_argnames=('side', 'sorter', 'method')) +@partial(jit, static_argnames=('side', 'method')) def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', sorter: None = None, *, method: str = 'scan') -> Array: - util.check_arraylike("searchsorted", a, v) + if sorter is None: + util.check_arraylike("searchsorted", a, v) + else: + util.check_arraylike("searchsorted", a, v, sorter) if side not in ['left', 'right']: raise ValueError(f"{side!r} is an invalid value for keyword 'side'. " "Expected one of ['left', 'right'].") @@ -7048,11 +7051,11 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', raise ValueError( f"{method!r} is an invalid value for keyword 'method'. " "Expected one of ['sort', 'scan', 'scan_unrolled', 'compare_all'].") - if sorter is not None: - raise NotImplementedError("sorter is not implemented") if ndim(a) != 1: raise ValueError("a should be 1-dimensional") a, v = util.promote_dtypes(a, v) + if sorter is not None: + a = a[sorter] dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64 if len(a) == 0: return zeros_like(v, dtype=dtype) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 050f39cedd15..18dec99a7fc9 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2660,13 +2660,23 @@ def np_fun(arg): side=['left', 'right'], dtype=number_dtypes, method=['sort', 'scan', 'scan_unrolled', 'compare_all'], + use_sorter=[True, False], ) - def testSearchsorted(self, ashape, vshape, side, dtype, method): + def testSearchsorted(self, ashape, vshape, side, dtype, method, use_sorter): rng = jtu.rand_default(self.rng()) - args_maker = lambda: [np.sort(rng(ashape, dtype)), rng(vshape, dtype)] - def np_fun(a, v): - return np.searchsorted(a, v, side=side).astype('int32') - jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side, method=method) + def args_maker(): + a = rng(ashape, dtype) + v = rng(vshape, dtype) + if use_sorter: + i = np.argsort(a) + return a, v, i + else: + a.sort() + return a, v + def np_fun(a, v, sorter=None): + return np.searchsorted(a, v, side=side, sorter=sorter).astype('int32') + def jnp_fun(a, v, sorter=None): + return jnp.searchsorted(a, v, side=side, method=method, sorter=sorter) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker)