[Pallas GPU] The behavior of jnp.sign(jnp.nan) in Pallas GPU does not match that of JAX #23504

Closed
ayaka14732 opened this issue Sep 9, 2024 · 0 comments · Fixed by #23523

Description

import functools
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np

def test_sign():
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )
    def kernel(x_ref, o_ref):
        o_ref[...] = jnp.sign(x_ref[...])

    val = jnp.nan
    x = jnp.full((8, 128), val, dtype=jnp.float32)
    out = kernel(x)
    expected = jnp.sign(x)
    np.testing.assert_array_equal(out, expected)

if __name__ == '__main__':
    test_sign()

Expected behavior:

The arrays are equal, so no error is raised.

Actual behavior:

Traceback (most recent call last):
  File "/home/ayx/jax/1.py", line 22, in <module>
    test_sign()
  File "/home/ayx/jax/1.py", line 19, in test_sign
    np.testing.assert_array_equal(out, expected)
  File "/home/ayx/venv/lib/python3.12/site-packages/numpy/_utils/__init__.py", line 85, in wrapper
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/venv/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 1025, in assert_array_equal
    assert_array_compare(operator.__eq__, actual, desired, err_msg=err_msg,
  File "/usr/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/venv/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 780, in assert_array_compare
    flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/venv/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 749, in func_assert_same_pos
    raise AssertionError(msg)
AssertionError: 
Arrays are not equal

nan location mismatch:
 ACTUAL: array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],...
 DESIRED: array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],...

There are two possible solutions:

  1. Add NaN support for jnp.sign() in Pallas GPU
  2. Document this behavior

Option (1) would match the JAX behavior but would be less performant, because it requires inserting a jnp.where() for every input.
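To make the trade-off in (1) concrete, here is a minimal sketch in plain NumPy, not the actual Pallas lowering. The helper names `naive_sign` and `nan_preserving_sign` are hypothetical, and the comparison-based `naive_sign` is only an assumed stand-in for an implementation that maps NaN to 0, as seen in the output above. The in-kernel equivalent would presumably use `jnp.where` and `jnp.isnan` in the same way.

```python
import numpy as np


def naive_sign(x):
    # Hypothetical stand-in: a comparison-based sign. Both comparisons
    # are False for NaN, so NaN inputs come out as 0.
    return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)


def nan_preserving_sign(x):
    # Option (1): one extra select per element passes NaN through
    # unchanged and uses the naive result everywhere else.
    return np.where(np.isnan(x), x, naive_sign(x))


x = np.array([-2.0, 0.0, 3.0, np.nan], dtype=np.float32)

print(naive_sign(x))           # NaN position becomes 0
print(nan_preserving_sign(x))  # NaN position stays NaN

# The NaN-preserving variant agrees with the reference np.sign.
np.testing.assert_array_equal(nan_preserving_sign(x), np.sign(x))
```

The extra `where` is applied to every element whether or not any NaN is present, which is the per-input cost mentioned above.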

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.32.dev20240830+4342c0c0f
jaxlib: 0.4.31
numpy:  2.1.0
python: 3.12.3 (main, Jul 31 2024, 17:43:48) [GCC 13.2.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='ayx-gpu-test-3.us-central1-f.c.jax-dev.internal', release='6.8.0-1013-gcp', version='#14-Ubuntu SMP Thu Aug  8 23:18:23 UTC 2024', machine='x86_64')


$ nvidia-smi
Mon Sep  9 11:36:35 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   30C    P0             59W /  400W |     429MiB /  40960MiB |      3%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      6300      C   python                                        416MiB |
+-----------------------------------------------------------------------------------------+
ayaka14732 added the bug label Sep 9, 2024
ayaka14732 self-assigned this Sep 9, 2024
copybara-service bot pushed a commit that referenced this issue Sep 9, 2024
… test case for `jnp.sign` into the general test

This PR is similar to #23192, which moves TPU test case for `lax.erf_inv` into the general test

Fixes #23504

PiperOrigin-RevId: 672481586