diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index b110b8f55c8c..9561b3435e50 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1512,6 +1512,8 @@ def kernel(x_ref, y_ref, o_ref): @parameterized.parameters("float32", "float64") def test_nextafter(self, dtype): + if jtu.test_device_matches(["tpu"]) and dtype == "float64": + self.skipTest("float64 disabled on TPU.") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), grid=1 )