Skip to content

Commit

Permalink
Make jnp.negative a ufunc & add unary ufunc tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 18, 2024
1 parent 1d84621 commit 628028f
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 30 deletions.
8 changes: 4 additions & 4 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,15 +909,15 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
"setitem": _unimplemented_setitem,
"copy": _copy,
"deepcopy": _deepcopy,
"neg": ufuncs.negative,
"pos": ufuncs.positive,
"neg": lambda self: ufuncs.negative(self),
"pos": lambda self: ufuncs.positive(self),
"eq": _defer_to_unrecognized_arg("==", ufuncs.equal),
"ne": _defer_to_unrecognized_arg("!=", ufuncs.not_equal),
"lt": _defer_to_unrecognized_arg("<", ufuncs.less),
"le": _defer_to_unrecognized_arg("<=", ufuncs.less_equal),
"gt": _defer_to_unrecognized_arg(">", ufuncs.greater),
"ge": _defer_to_unrecognized_arg(">=", ufuncs.greater_equal),
"abs": ufuncs.abs,
"abs": lambda self: ufuncs.abs(self),
"add": _defer_to_unrecognized_arg("+", ufuncs.add),
"radd": _defer_to_unrecognized_arg("+", ufuncs.add, swap=True),
"sub": _defer_to_unrecognized_arg("-", ufuncs.subtract),
Expand All @@ -944,7 +944,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
"ror": _defer_to_unrecognized_arg("|", ufuncs.bitwise_or, swap=True),
"xor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor),
"rxor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor, swap=True),
"invert": ufuncs.bitwise_not,
"invert": lambda self: ufuncs.bitwise_not(self),
"lshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift),
"rshift": _defer_to_unrecognized_arg(">>", ufuncs.right_shift),
"rlshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift, swap=True),
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def invert(x: ArrayLike, /) -> Array:


@partial(jit, inline=True)
def negative(x: ArrayLike, /) -> Array:
def _negative(x: ArrayLike, /) -> Array:
"""Return element-wise negative values of the input.
JAX implementation of :obj:`numpy.negative`.
Expand Down Expand Up @@ -2221,3 +2221,4 @@ def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = No
logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, call=_logical_and, reduce=_logical_and_reduce)
logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce)
logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor)
negative = ufunc(_negative, name="negative", nin=1, nout=1, call=_negative)
94 changes: 69 additions & 25 deletions tests/lax_numpy_ufuncs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,35 @@ def _jnp_ufunc_props(name):
jnp_func = getattr(jnp, name)
assert isinstance(jnp_func, jnp.ufunc)
np_func = getattr(np, name)
dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types]
dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types or f"{c}->{c}" in np_func.types]
return [dict(name=name, dtype=dtype) for dtype in dtypes]


JAX_NUMPY_UFUNCS = [
name for name in dir(jnp) if isinstance(getattr(jnp, name), jnp.ufunc)
]

BINARY_UFUNCS = [
name for name in JAX_NUMPY_UFUNCS if getattr(jnp, name).nin == 2
]

UNARY_UFUNCS = [
name for name in JAX_NUMPY_UFUNCS if getattr(jnp, name).nin == 1
]

JAX_NUMPY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable(
_jnp_ufunc_props(name) for name in JAX_NUMPY_UFUNCS
))

BINARY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable(
_jnp_ufunc_props(name) for name in BINARY_UFUNCS
))

UNARY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable(
_jnp_ufunc_props(name) for name in UNARY_UFUNCS
))


broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
nonscalar_shapes = [(3,), (4,), (4, 3)]

Expand Down Expand Up @@ -144,12 +161,25 @@ def test_frompyfunc_call(self, func, nin, nout, identity, lhs_shape, rhs_shape,
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
UNARY_UFUNCS_WITH_DTYPES,
shape=broadcast_compatible_shapes,
)
def test_unary_ufunc_call(self, name, dtype, shape):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]

self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
BINARY_UFUNCS_WITH_DTYPES,
lhs_shape=broadcast_compatible_shapes,
rhs_shape=broadcast_compatible_shapes,
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def test_ufunc_call(self, name, dtype, lhs_shape, rhs_shape):
def test_bimary_ufunc_call(self, name, dtype, lhs_shape, rhs_shape):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
rng = jtu.rand_default(self.rng())
Expand Down Expand Up @@ -177,15 +207,13 @@ def test_frompyfunc_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape,
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
BINARY_UFUNCS_WITH_DTYPES,
lhs_shape=broadcast_compatible_shapes,
rhs_shape=broadcast_compatible_shapes,
)
def test_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype):
def test_binary_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")

rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
Expand Down Expand Up @@ -213,16 +241,15 @@ def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype):
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
BINARY_UFUNCS_WITH_DTYPES,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in [None, *range(-len(shape), len(shape))]],
)
def test_ufunc_reduce(self, name, shape, axis, dtype):
def test_binary_ufunc_reduce(self, name, shape, axis, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")

jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis)
np_fun_reduce = partial(np_fun.reduce, axis=axis)

Expand Down Expand Up @@ -266,16 +293,15 @@ def np_fun(arr, where):
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
BINARY_UFUNCS_WITH_DTYPES,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in [None, *range(-len(shape), len(shape))]],
)
def test_ufunc_reduce_where(self, name, shape, axis, dtype):
def test_binary_ufunc_reduce_where(self, name, shape, axis, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")

if jnp_fun.identity is None:
self.skipTest("reduce with where requires identity")

Expand Down Expand Up @@ -309,16 +335,14 @@ def test_frompyfunc_accumulate(self, func, nin, nout, identity, shape, axis, dty
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
BINARY_UFUNCS_WITH_DTYPES,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))]
)
def test_ufunc_accumulate(self, name, shape, axis, dtype):
def test_binary_ufunc_accumulate(self, name, shape, axis, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")

rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
Expand Down Expand Up @@ -355,15 +379,35 @@ def np_fun(x, idx, y):
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
UNARY_UFUNCS_WITH_DTYPES,
shape=nonscalar_shapes,
idx_shape=[(), (2,)],
)
def test_ufunc_at(self, name, shape, idx_shape, dtype):
def test_unary_ufunc_at(self, name, shape, idx_shape, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)

rng = jtu.rand_default(self.rng())
idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0])
args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')]

jnp_fun_at = partial(jnp_fun.at, inplace=False)
def np_fun_at(x, idx):
x_copy = x.copy()
np_fun.at(x_copy, idx)
return x_copy

self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker)
self._CompileAndCheck(jnp_fun_at, args_maker)

@jtu.sample_product(
BINARY_UFUNCS_WITH_DTYPES,
shape=nonscalar_shapes,
idx_shape=[(), (2,)],
)
def test_binary_ufunc_at(self, name, shape, idx_shape, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}")

rng = jtu.rand_default(self.rng())
idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0])
Expand Down Expand Up @@ -413,13 +457,13 @@ def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_s
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
JAX_NUMPY_UFUNCS_WITH_DTYPES,
BINARY_UFUNCS_WITH_DTYPES,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in [*range(-len(shape), len(shape))]],
idx_shape=[(0,), (3,), (5,)],
)
def test_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype):
def test_binary_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype):
jnp_fun = getattr(jnp, name)
np_fun = getattr(np, name)
if (jnp_fun.nin, jnp_fun.nout) != (2, 1):
Expand Down

0 comments on commit 628028f

Please sign in to comment.