diff --git a/tests/memories_test.py b/tests/memories_test.py index 9b8b990d674b..affe5de99644 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1388,6 +1388,38 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P())) self.assertEqual(out.sharding.memory_kind, 'device') + def test_compute_offload_with_donation(self): + sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) + p_sharding = jax.sharding.SingleDeviceSharding( + jax.devices()[0], memory_kind="pinned_host" + ) + + @compute_on("device_host") + @jax.jit + def host_fn(x_in, y_in): + return x_in * x_in, y_in + y_in + + def test_fn(x_in, y_in): + x_out, y_out = host_fn(x_in, y_in) + return x_out, y_out + + x = jnp.arange(0, 1024, dtype=jnp.float32) + y = jnp.arange(0, 1024, dtype=jnp.float32) + y = jax.device_put(y, p_sharding) + + x1 = jnp.arange(0, 1024, dtype=jnp.float32) + y1 = jnp.arange(0, 1024, dtype=jnp.float32) + + jit_fn = jax.jit( + test_fn, + in_shardings=(sharding, p_sharding), + out_shardings=(sharding, p_sharding), + donate_argnums=(0, 1), + ) + x_out, y_out = jit_fn(x, y) + self.assertArraysEqual(x_out, x1 * x1) + self.assertArraysEqual(y_out, y1 + y1) + class ActivationOffloadingTest(jtu.JaxTestCase):