Skip to content

Commit

Permalink
Adds test_compute_offload_with_donation in memories_test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671410527
  • Loading branch information
Google-ML-Automation authored and jax authors committed Sep 5, 2024
1 parent 8fe99ff commit 97db78b
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,38 @@ def f(x):
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
self.assertEqual(out.sharding.memory_kind, 'device')

def test_compute_offload_with_donation(self):
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
p_sharding = jax.sharding.SingleDeviceSharding(
jax.devices()[0], memory_kind="pinned_host"
)

@compute_on("device_host")
@jax.jit
def host_fn(x_in, y_in):
return x_in * x_in, y_in + y_in

def test_fn(x_in, y_in):
x_out, y_out = host_fn(x_in, y_in)
return x_out, y_out

x = jnp.arange(0, 1024, dtype=jnp.float32)
y = jnp.arange(0, 1024, dtype=jnp.float32)
y = jax.device_put(y, p_sharding)

x1 = jnp.arange(0, 1024, dtype=jnp.float32)
y1 = jnp.arange(0, 1024, dtype=jnp.float32)

jit_fn = jax.jit(
test_fn,
in_shardings=(sharding, p_sharding),
out_shardings=(sharding, p_sharding),
donate_argnums=(0, 1),
)
x_out, y_out = jit_fn(x, y)
self.assertArraysEqual(x_out, x1 * x1)
self.assertArraysEqual(y_out, y1 + y1)


class ActivationOffloadingTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 97db78b

Please sign in to comment.