Skip to content

Commit

Permalink
Make JAX extract the mesh from an AUTO in/out sharding.
Browse files Browse the repository at this point in the history
Automatic partitioners using JAX+Shardy want to partition models which are fully marked as `AUTO` - so no in/out sharding with a `NamedSharding`. In such a case they weren't seeing the mesh on the MLIR module. This makes sure we extract it from the `AUTO` sharding.

PiperOrigin-RevId: 672881018
  • Loading branch information
bartchr808 authored and jax authors committed Sep 10, 2024
1 parent 7d2f0a7 commit 062a69a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2217,7 +2217,7 @@ def lower_sharding_computation(
if config.use_shardy_partitioner.value or prim_requires_devices:
for sharding in it.chain(in_shardings, out_shardings,
[js for js, _ in unique_intermediate_shardings]):
if isinstance(sharding, sharding_impls.NamedSharding):
if isinstance(sharding, (sharding_impls.NamedSharding, sharding_impls.AUTO)):
if (mesh_shape_tuple is not None and
mesh_shape_tuple != sharding.mesh.shape_tuple):
raise ValueError(
Expand Down
17 changes: 14 additions & 3 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5288,17 +5288,28 @@ def f(x):
def test_compile_with_inferred_out_sharding(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
x = jax.device_put(np.arange(8 * 4).reshape(8, 4),
jax.sharding.NamedSharding(mesh, P('x', 'y')))
NamedSharding(mesh, P('x', 'y')))
y = jax.device_put(np.arange(4 * 16).reshape(4, 16),
jax.sharding.NamedSharding(mesh, P('y')))
NamedSharding(mesh, P('y')))

@jax.jit
def f(x, y):
return x @ y

out = f(x, y)
self.assertArraysEqual(out, x @ y)
self.assertEqual(out.sharding, jax.sharding.NamedSharding(mesh, P('x')))
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))

def test_fully_automatic_sharding(self):
mesh = jtu.create_mesh((8,), ('x',))
x = jax.ShapeDtypeStruct((128, 128), jnp.float32)

@jax.jit
def f(x, y):
return x @ y

lowered_str = jax.jit(f, in_shardings=[AUTO(mesh), AUTO(mesh)]).lower(x, x).as_text()
self.assertIn('sdy.mesh @mesh = <["x"=8]>', lowered_str)


if __name__ == '__main__':
Expand Down

0 comments on commit 062a69a

Please sign in to comment.