Skip to content

Commit

Permalink
Update the multi-process note in pjit's docstring
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 632160561
  • Loading branch information
yashk2810 authored and jax authors committed May 9, 2024
1 parent 2be3f6d commit 671fb12
Showing 1 changed file with 3 additions and 12 deletions.
15 changes: 3 additions & 12 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,19 +728,10 @@ def pjit(
processes run the same :func:`~pjit`'d function in the same order.
When running in this configuration, the mesh should contain devices across
all processes. However, any input argument dimensions partitioned over
multi-process mesh axes should be of size equal to the corresponding *local*
mesh axis size, and outputs will be similarly sized according to the local
mesh. ``fun`` will still be executed across *all* devices in the mesh,
all processes. All inputs arguments must be globally shaped.
``fun`` will still be executed across *all* devices in the mesh,
including those from other processes, and will be given a global view of the
data spread across multiple processes as a single array. However, outside
of :func:`~pjit` every process only "sees" its local piece of the input and output,
corresponding to its local sub-mesh.
This means that each process's participating local devices must form a
_contiguous_ local sub-mesh within the full global mesh. A contiguous
sub-mesh is one where all of its devices are adjacent within the global
mesh, and form a rectangular prism.
data spread across multiple processes as a single array.
The SPMD model also requires that the same multi-process :func:`~pjit`'d
functions must be run in the same order on all processes, but they can be
Expand Down

0 comments on commit 671fb12

Please sign in to comment.