[Pallas GPU] The behavior of jnp.sign(jnp.nan) in Pallas GPU does not match that of JAX
#23504
Labels: bug
Description
Expected behavior:
The arrays are equal, so there will be no error.
Actual behavior:
We can have 2 different solutions:
jnp.nan support for jnp.sign()
(1) would allow us to match the JAX behavior, but would be less performant because it requires inserting jnp.where() for every input.
System info (python version, jaxlib version, accelerator, etc.)
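To make the cost of solution (1) from the description concrete, here is a minimal sketch in plain JAX, not Pallas. The helper name nan_preserving_sign is hypothetical, and the comparison-based sign is only a stand-in for a sign computation with no special NaN handling; it is an assumption and not how Pallas GPU actually lowers jnp.sign(). The point is the extra jnp.where() select that every input would have to go through.

```python
import jax.numpy as jnp
import numpy as np


def nan_preserving_sign(x):
    # Hypothetical helper for illustration; not part of JAX or Pallas.
    # Stand-in for a sign with no NaN handling: comparisons against NaN
    # are False, so a NaN input yields 0 at this step.
    raw = (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)
    # The extra select that solution (1) would insert for every input:
    # put NaN back wherever the input was NaN.
    return jnp.where(jnp.isnan(x), jnp.nan, raw)


x = jnp.array([-2.0, 0.0, 3.0, jnp.nan], dtype=jnp.float32)
out = nan_preserving_sign(x)

# Reference behavior in plain JAX: jnp.sign() propagates NaN.
ref = jnp.sign(x)

# assert_array_equal treats NaNs in matching positions as equal,
# so this passes without raising.
np.testing.assert_array_equal(out, ref)
```

The select runs on every element even when no NaN is present, which is the performance concern raised for (1).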