diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index d53684743b94..d2424fb7047f 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -728,19 +728,10 @@ def pjit( processes run the same :func:`~pjit`'d function in the same order. When running in this configuration, the mesh should contain devices across - all processes. However, any input argument dimensions partitioned over - multi-process mesh axes should be of size equal to the corresponding *local* - mesh axis size, and outputs will be similarly sized according to the local - mesh. ``fun`` will still be executed across *all* devices in the mesh, + all processes. All inputs arguments must be globally shaped. + ``fun`` will still be executed across *all* devices in the mesh, including those from other processes, and will be given a global view of the - data spread across multiple processes as a single array. However, outside - of :func:`~pjit` every process only "sees" its local piece of the input and output, - corresponding to its local sub-mesh. - - This means that each process's participating local devices must form a - _contiguous_ local sub-mesh within the full global mesh. A contiguous - sub-mesh is one where all of its devices are adjacent within the global - mesh, and form a rectangular prism. + data spread across multiple processes as a single array. The SPMD model also requires that the same multi-process :func:`~pjit`'d functions must be run in the same order on all processes, but they can be