Skip to content

Commit

Permalink
filecheck test: use lax.cumsum directly to prevent false-positive
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 4, 2024
1 parent e7d3785 commit 0e6650e
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions tests/filecheck/subcomputations.filecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from absl import app

import jax
from jax import numpy as jnp
from jax.interpreters import mlir
from jax._src.lib.mlir import ir
import numpy as np
Expand All @@ -39,7 +38,7 @@ def main(_):
# CHECK-NOT: func private @cumsum
@print_ir(np.empty([2, 7], np.int32), np.empty([2, 7], np.int32))
def cumsum_only_once(x, y):
return jnp.cumsum(x) + jnp.cumsum(y)
return jax.lax.cumsum(x) + jax.lax.cumsum(y)

# Test merging modules
# CHECK-LABEL: TEST: merge_modules
Expand Down

0 comments on commit 0e6650e

Please sign in to comment.