From 97cc881b9c46c1f68ca5e052034690389281e341 Mon Sep 17 00:00:00 2001
From: Ayaka
Date: Mon, 9 Sep 2024 03:24:06 -0700
Subject: [PATCH] [Pallas GPU] Fix the behavior of `jnp.sign(jnp.nan)` and move
 the TPU test case for `jnp.sign` into the general test

This PR is similar to https://github.com/google/jax/pull/23192, which moves
the TPU test case for `lax.erf_inv` into the general test.

Fixes https://github.com/google/jax/issues/23504

PiperOrigin-RevId: 672481586
---
 jax/_src/pallas/mosaic/lowering.py | 18 ++------
 jax/_src/pallas/triton/lowering.py | 12 ++----
 jax/_src/pallas/utils.py           | 14 +++++++
 tests/pallas/ops_test.py           | 66 +++++++++++++++---------------
 4 files changed, 52 insertions(+), 58 deletions(-)

diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index c8f910e185f3..bd897deb3d1f 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -1838,22 +1838,10 @@ def _neg_lowering_rule(ctx: LoweringRuleContext, x):
 skip_mlir_conversions.add(lax.neg_p)
 
 
-def _sign_lowering_helper(x):
-  if jnp.issubdtype(x.dtype, jnp.unsignedinteger):
-    return (x != 0).astype(x.dtype)
-
-  if jnp.issubdtype(x.dtype, jnp.integer):
-    return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)
-
-  if jnp.issubdtype(x.dtype, jnp.floating):
-    out = (x > 0.).astype(x.dtype) - (x < 0.).astype(x.dtype)
-    return jnp.where(jnp.isnan(x), jnp.nan, out)
-
-  raise NotImplementedError
-
-
 def _sign_lowering_rule(ctx: LoweringRuleContext, x):
-  return lower_fun(_sign_lowering_helper, multiple_results=False)(ctx, x)
+  return lower_fun(
+      pallas_utils.sign_lowering_helper, multiple_results=False,
+  )(ctx, x)
 
 
 lowering_rules[lax.sign_p] = _sign_lowering_rule
diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index 446d87a5f347..dec48847520a 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -1359,15 +1359,9 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
   return _floordiv(x, y, signed=signed)
 
 
-@register_lowering(lax.sign_p)
-def _sign_lowering_rule(ctx: LoweringRuleContext, x):
-  [x_aval] = ctx.avals_in
-  signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger)
-  zero = _full(x.type, 0)
-  return _sub(
-      _cast(_greater_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype),
-      _cast(_less_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype),
-  )
+register_lowering(lax.sign_p)(
+    lower_fun(pallas_utils.sign_lowering_helper, multiple_results=False)
+)
 
 
 @register_lowering(lax.iota_p)
diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py
index 6fc816e27b53..cfca0769d13d 100644
--- a/jax/_src/pallas/utils.py
+++ b/jax/_src/pallas/utils.py
@@ -210,3 +210,17 @@ def erf_inv_32_lowering_helper(x):
     p = c + p * w
 
   return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x)
+
+
+def sign_lowering_helper(x):
+  if jnp.issubdtype(x.dtype, jnp.unsignedinteger):
+    return (x != 0).astype(x.dtype)
+
+  if jnp.issubdtype(x.dtype, jnp.integer):
+    return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)
+
+  if jnp.issubdtype(x.dtype, jnp.floating):
+    out = (x > 0.).astype(x.dtype) - (x < 0.).astype(x.dtype)
+    return jnp.where(jnp.isnan(x), jnp.nan, out)
+
+  raise NotImplementedError(f"sign_lowering_helper not implemented for {x.dtype}")
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 84844151eb63..52a1d62a63a3 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -669,6 +669,38 @@ def run(interpret=False):
     actual = run(False)
     self.assertAllClose(actual, expected)
 
+  SIGN_PARAMS = [
+      (jnp.int32, (-3, 0, 5)),
+      (jnp.uint32, (0, 5)),
+      (jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)),
+      (jnp.float64, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)),
+  ]
+
+  @parameterized.named_parameters(
+      (f"{dtype.__name__}_{value}", dtype, value)
+      for dtype, values in SIGN_PARAMS
+      for value in values
+  )
+  def test_sign(self, dtype, value):
+    if jtu.test_device_matches(["tpu"]) and dtype == jnp.float64:
+      self.skipTest("float64 is not supported on TPU")
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
+    )
+    def kernel(x_ref, o_ref):
+      o_ref[...] = jnp.sign(x_ref[...])
+
+    with contextlib.ExitStack() as stack:
+      if jnp.dtype(dtype).itemsize == 8:
+        stack.enter_context(config.enable_x64(True))
+
+      x = jnp.full((8, 128,), value, dtype=dtype)
+      out = kernel(x)
+      expected = jnp.sign(x)
+      np.testing.assert_array_equal(out, expected)
+
 
 class OpsInterpretTest(OpsTest):
   INTERPRET = True
@@ -1614,39 +1646,5 @@ class PallasPrimitivesInterpretTest(PallasPrimitivesTest):
   INTERPRET = True
 
 
-class TpuOpsTest(PallasBaseTest):
-
-  def setUp(self):
-    if not jtu.test_device_matches(["tpu"]):
-      self.skipTest("Test requires TPU device.")
-
-    super().setUp()
-
-  SIGN_PARAMS = [
-      (jnp.int32, (-3, 0, 5)),
-      (jnp.uint32, (0, 5)),
-      (jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)),
-  ]
-
-  @parameterized.named_parameters(
-      (f"{dtype.__name__}_{value}", dtype, value)
-      for dtype, values in SIGN_PARAMS
-      for value in values
-  )
-  def test_sign(self, dtype, value):
-    @jax.jit
-    @functools.partial(
-        pl.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
-    )
-    def kernel(x_ref, o_ref):
-      o_ref[...] = jnp.sign(x_ref[...])
-
-    x = jnp.full((8, 128,), value, dtype=dtype)
-    out = kernel(x)
-    expected = jnp.sign(x)
-    np.testing.assert_array_equal(out, expected)
-
-
 if __name__ == "__main__":
   absltest.main()
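
For reference, a minimal standalone sketch of the behavior this change fixes: after the fix, `jnp.sign` applied to NaN inside a Pallas kernel propagates NaN, matching `jnp.sign` on the host. The sketch is not part of the patch; the (8, 128) shape mirrors the test above, and `interpret=True` is an illustrative choice so the snippet can run without a GPU or TPU.

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


@jax.jit
@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    interpret=True,  # illustrative: run in interpret mode, no accelerator needed
)
def sign_kernel(x_ref, o_ref):
  o_ref[...] = jnp.sign(x_ref[...])


x = jnp.full((8, 128), jnp.nan, dtype=jnp.float32)
# jnp.sign(nan) is nan; after this change the kernel output agrees with the
# host value (assert_array_equal treats NaNs in matching positions as equal).
np.testing.assert_array_equal(sign_kernel(x), jnp.sign(x))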