diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 95b195e114be..45f933a18b6b 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -2148,16 +2148,11 @@ class RmsNormInterpreterTest(PallasTest):
 
 class SoftmaxTest(PallasTest):
 
-  @parameterized.parameters(
-      (shape, dtype)
-      for shape in [(1024, 125), (4, 1024, 125)]
-      for dtype in (jnp.bfloat16, jnp.float16, jnp.float32)
+  @parameterized.product(
+      shape=[(1024, 125), (4, 1024, 125)],
+      dtype=[jnp.bfloat16, jnp.float16, jnp.float32]
   )
   def test_softmax(self, shape, dtype):
-    # TODO(bchetioui): add Triton bug reference when filed
-    if dtype == jnp.bfloat16:
-      raise absltest.SkipTest("Disabled due to Triton lowering bug")
-
     x = jax.random.normal(random.key(0), shape, dtype=dtype)
 
     atol, rtol = {
@@ -2166,9 +2161,11 @@ def test_softmax(self, shape, dtype):
       jnp.float32: (1e-7, 1e-6),
     }[dtype]
 
+    # We upcast to float32 because NumPy <2.0 does not handle custom dtypes
+    # properly. See https://github.com/google/jax/issues/11014.
     np.testing.assert_allclose(
-        softmax.softmax(x, axis=-1),
-        jax.nn.softmax(x, axis=-1),
+        softmax.softmax(x, axis=-1).astype(jnp.float32),
+        jax.nn.softmax(x, axis=-1).astype(jnp.float32),
         atol=atol,
         rtol=rtol,
     )
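
For readers unfamiliar with absl's two parameterization styles, here is a minimal standalone sketch (not part of this PR; the class and test names are hypothetical) contrasting the generator-based `parameterized.parameters` call removed above with the `parameterized.product` call that replaces it:

```python
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np


class ProductStyleTest(parameterized.TestCase):

  # Old style: spell out the cross product with a generator expression.
  @parameterized.parameters(
      (shape, dtype)
      for shape in [(4,), (2, 4)]
      for dtype in ("float32", "float16")
  )
  def test_parameters_style(self, shape, dtype):
    x = np.zeros(shape, dtype=dtype)
    self.assertEqual(x.shape, shape)

  # New style: declare each axis once; absl expands the 2 x 2 = 4 cases.
  @parameterized.product(
      shape=[(4,), (2, 4)],
      dtype=["float32", "float16"],
  )
  def test_product_style(self, shape, dtype):
    x = np.zeros(shape, dtype=dtype)
    self.assertEqual(x.shape, shape)


if __name__ == "__main__":
  absltest.main()
```

`parameterized.product` expands the keyword lists into their cross product, so the two shapes and three dtypes in the real test still yield six cases. Separately, the `.astype(jnp.float32)` upcast in the assertion works around `np.testing.assert_allclose`'s handling of custom dtypes such as bfloat16 under NumPy < 2.0, as the new in-diff comment notes.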