Skip to content

Commit

Permalink
Disable two lax_scipy_test testcases that fail on TPU v6e.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 672973757
  • Loading branch information
hawkinsp authored and jax authors committed Sep 10, 2024
1 parent 062a69a commit 1b2ba9d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,8 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool:
# Special case v5e until the name is updated in device_kind
if expected_version == "v5e":
return "v5 lite" in device_kind
elif expected_version == "v6e":
return "v6 lite" in device_kind
return expected_version in device_kind

def is_cuda_compute_capability_at_least(capability: str) -> bool:
Expand Down
4 changes: 4 additions & 0 deletions tests/lax_scipy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ def scipy_fun(z):
dtype=float_dtypes,
)
def testLpmn(self, l_max, shape, dtype):
if jtu.is_device_tpu(6, "e"):
self.skipTest("TODO(b/364258243): fails on TPU v6e")
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
args_maker = lambda: [rng(shape, dtype)]

Expand Down Expand Up @@ -442,6 +444,8 @@ def testSphHarmOrderOneDegreeOne(self):
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
"""Tests against JIT compatibility and Numpy."""
if jtu.is_device_tpu(6, "e"):
self.skipTest("TODO(b/364258243): fails on TPU v6e")
n_max = l_max
shape = (num_z,)
rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)
Expand Down

0 comments on commit 1b2ba9d

Please sign in to comment.