Skip to content

Commit

Permalink
[Pallas GPU] Fix the behavior of jnp.sign(jnp.nan) and move the TPU…
Browse files Browse the repository at this point in the history
… test case for `jnp.sign` into the general test

This PR is similar to #23192, which moves TPU test case for `lax.erf_inv` into the general test

Fixes #23504

PiperOrigin-RevId: 672481586
  • Loading branch information
ayaka14732 authored and jax authors committed Sep 9, 2024
1 parent d6c3625 commit 97cc881
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 58 deletions.
18 changes: 3 additions & 15 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1838,22 +1838,10 @@ def _neg_lowering_rule(ctx: LoweringRuleContext, x):
skip_mlir_conversions.add(lax.neg_p)


def _sign_lowering_helper(x):
if jnp.issubdtype(x.dtype, jnp.unsignedinteger):
return (x != 0).astype(x.dtype)

if jnp.issubdtype(x.dtype, jnp.integer):
return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)

if jnp.issubdtype(x.dtype, jnp.floating):
out = (x > 0.).astype(x.dtype) - (x < 0.).astype(x.dtype)
return jnp.where(jnp.isnan(x), jnp.nan, out)

raise NotImplementedError


def _sign_lowering_rule(ctx: LoweringRuleContext, x):
return lower_fun(_sign_lowering_helper, multiple_results=False)(ctx, x)
return lower_fun(
pallas_utils.sign_lowering_helper, multiple_results=False,
)(ctx, x)


lowering_rules[lax.sign_p] = _sign_lowering_rule
Expand Down
12 changes: 3 additions & 9 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,15 +1359,9 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
return _floordiv(x, y, signed=signed)


@register_lowering(lax.sign_p)
def _sign_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger)
zero = _full(x.type, 0)
return _sub(
_cast(_greater_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype),
_cast(_less_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype),
)
register_lowering(lax.sign_p)(
lower_fun(pallas_utils.sign_lowering_helper, multiple_results=False)
)


@register_lowering(lax.iota_p)
Expand Down
14 changes: 14 additions & 0 deletions jax/_src/pallas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,17 @@ def erf_inv_32_lowering_helper(x):
p = c + p * w

return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x)


def sign_lowering_helper(x):
if jnp.issubdtype(x.dtype, jnp.unsignedinteger):
return (x != 0).astype(x.dtype)

if jnp.issubdtype(x.dtype, jnp.integer):
return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)

if jnp.issubdtype(x.dtype, jnp.floating):
out = (x > 0.).astype(x.dtype) - (x < 0.).astype(x.dtype)
return jnp.where(jnp.isnan(x), jnp.nan, out)

raise NotImplementedError(f"sign_lowering_helper not implemented for {x.dtype}")
66 changes: 32 additions & 34 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,38 @@ def run(interpret=False):
actual = run(False)
self.assertAllClose(actual, expected)

SIGN_PARAMS = [
(jnp.int32, (-3, 0, 5)),
(jnp.uint32, (0, 5)),
(jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)),
(jnp.float64, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)),
]

@parameterized.named_parameters(
(f"{dtype.__name__}_{value}", dtype, value)
for dtype, values in SIGN_PARAMS
for value in values
)
def test_sign(self, dtype, value):
if jtu.test_device_matches(["tpu"]) and dtype == jnp.float64:
self.skipTest("float64 is not supported on TPU")

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.sign(x_ref[...])

with contextlib.ExitStack() as stack:
if jnp.dtype(dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))

x = jnp.full((8, 128,), value, dtype=dtype)
out = kernel(x)
expected = jnp.sign(x)
np.testing.assert_array_equal(out, expected)


class OpsInterpretTest(OpsTest):
INTERPRET = True
Expand Down Expand Up @@ -1614,39 +1646,5 @@ class PallasPrimitivesInterpretTest(PallasPrimitivesTest):
INTERPRET = True


class TpuOpsTest(PallasBaseTest):

def setUp(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Test requires TPU device.")

super().setUp()

SIGN_PARAMS = [
(jnp.int32, (-3, 0, 5)),
(jnp.uint32, (0, 5)),
(jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)),
]

@parameterized.named_parameters(
(f"{dtype.__name__}_{value}", dtype, value)
for dtype, values in SIGN_PARAMS
for value in values
)
def test_sign(self, dtype, value):
@jax.jit
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.sign(x_ref[...])

x = jnp.full((8, 128,), value, dtype=dtype)
out = kernel(x)
expected = jnp.sign(x)
np.testing.assert_array_equal(out, expected)


if __name__ == "__main__":
absltest.main()

0 comments on commit 97cc881

Please sign in to comment.