Skip to content

Commit

Permalink
Make make_array_from_process_local_data go via device_put if ther…
Browse files Browse the repository at this point in the history
…e is only 1 process.

PiperOrigin-RevId: 677046372
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Sep 21, 2024
1 parent d63afd8 commit 5961f60
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 15 deletions.
9 changes: 6 additions & 3 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,17 +892,20 @@ def make_array_from_process_local_data(
setting it to (4, 4) in this case.
Args:
sharding: sharding of the global tensor.
local_data: data on the host to be placed on local devices. Each
sharding: Sharding of the global array.
local_data: Data on the host to be placed on local devices. Each
dimension should either match global_shape, or match
num_addressable_indices(dim).
global_shape: the target shape of the global tensor. If None,
global_shape: The target shape of the global array. If None,
will infer from local_data and sharding.
Returns:
Tensor that will have sharding=sharding and of shape global_shape.
"""
# pyformat: enable
if xla_bridge.process_count() == 1:
return api.device_put(local_data, sharding)

# TODO(sandler): consider supporting partially specified global_shape or
# making local_to_global_shape available in the api.
local_shape = local_data.shape
Expand Down
18 changes: 6 additions & 12 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,18 +822,12 @@ def test_make_array_from_callback_global_array(self):
self.assertEqual(out2.sharding, sharding2)

def test_make_array_from_process_data_single_host_data_sharding(self):
data = np.ones((1, 512))
mesh = jtu.create_mesh((1, 1), ('x', 'unused'))
sharding_spec = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec('x')
)
global_shape = data.shape
result = jax.make_array_from_process_local_data(
sharding_spec, data, global_shape
)
self.assertIsInstance(result, jax.Array)
self.assertEqual(result.shape, data.shape)
self.assertEqual(result.sharding, sharding_spec)
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
data = np.ones((256, 512))
s = jax.NamedSharding(mesh, P('x'))
result = jax.make_array_from_process_local_data(s, data)
self.assertArraysEqual(result, data)
self.assertEqual(result.sharding, s)

class ShardingTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 5961f60

Please sign in to comment.