Skip to content

Commit

Permalink
[Pallas/TPU] Add API for megacore partitioning of pipelines
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 644184524
  • Loading branch information
sharadmv authored and jax authors committed Jun 18, 2024
1 parent fb68f34 commit 701c63e
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 9 deletions.
99 changes: 91 additions & 8 deletions jax/_src/pallas/mosaic/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Module for emitting custom TPU pipelines within a Pallas call."""

from __future__ import annotations

import dataclasses
import enum
import functools
Expand All @@ -24,6 +26,7 @@
import jax
from jax import lax
from jax import tree_util
from jax._src import util as jax_util
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import primitives as tpu_primitives
Expand Down Expand Up @@ -87,15 +90,16 @@ def _grid_size(grid):
return size


def _get_indices(step, grid):
def _get_indices(step, grid, offsets):
"""Get indices for a given step and grid."""
extended_grid = grid + (1,)
strides = tuple(
itertools.accumulate(extended_grid[::-1], func=operator.mul))[::-1]
return tuple(
indices = tuple(
lax.div(lax.rem(step, a), b)
for a, b in zip(strides[:-1], strides[1:])
)
return tuple(a + b for a, b in zip(indices, offsets, strict=True))


class BufferType(enum.Enum):
Expand Down Expand Up @@ -350,8 +354,9 @@ class Scheduler:
"""Sequences input and output copies and waits for a pipeline."""

def __init__(self,
step,
grid,
step: jax.Array,
grid: tuple[int | jax.Array, ...],
grid_offsets: tuple[int | jax.Array, ...],
first_cycle=None,
last_cycle=None,
init_accumulators=None,
Expand All @@ -361,6 +366,7 @@ def __init__(self,
Args:
step: inner step number.
grid: pallas grid for BufferedRefs.
grid_offsets: offsets for grid indices (used for megacore).
first_cycle: whether this is the first invocation of the pipeline.
last_cycle: whether this is the last invocation of the pipeline.
init_accumulators: do we zero-initialize accumulator state for this
Expand Down Expand Up @@ -388,9 +394,13 @@ def __init__(self,
self.next_step = _mod(step + 1, self.num_steps)

# Derived grid indices for present, previous, and next steps.
self.indices = _get_indices(step, grid)
self.prev_indices = _get_indices(self.prev_step, self.grid)
self.next_indices = _get_indices(self.next_step, self.grid)
self.indices = _get_indices(step, grid, grid_offsets)
self.prev_indices = _get_indices(
self.prev_step, grid, grid_offsets
)
self.next_indices = _get_indices(
self.next_step, grid, grid_offsets
)

def grid_env(self):
return pallas_core.grid_env(
Expand Down Expand Up @@ -628,13 +638,79 @@ def make_output_bref(out_spec, out_ref, accumulate):
return (*in_brefs, *out_brefs)


class GridDimensionSemantics:
pass
PARALLEL = GridDimensionSemantics()
ARBITRARY = GridDimensionSemantics()


def _partition_grid(
grid: tuple[int | jax.Array, ...],
core_axis: int | None,
dimension_semantics: tuple[GridDimensionSemantics, ...] | None,
) -> tuple[tuple[int | jax.Array, ...], tuple[int | jax.Array, ...]]:
if core_axis is None:
# We aren't partitioning the grid
return grid, (0,) * len(grid)
num_cores = pl.num_programs(core_axis)
# Check that num_cores is statically known
if not isinstance(num_cores, int):
raise NotImplementedError(
f"Cannot partition grid over dynamic number of cores: {core_axis=}"
)
if num_cores == 1:
# We aren't partitioning the grid
return grid, (0,) * len(grid)

# If dimension_semantics aren't provided, we assume it is all arbitrary.
if dimension_semantics is None:
dimension_semantics = (ARBITRARY,) * len(grid)
if len(dimension_semantics) != len(grid):
raise ValueError("dimension_semantics must be the same length as grid.")

parallel_dimensions = {i for i, d in enumerate(dimension_semantics)
if d == PARALLEL}
# If there are no parallel dimensions, we can't partition the grid
if not parallel_dimensions:
# TODO(sharadmv): enable running kernel on just one core
raise NotImplementedError(
"Cannot partition over cores without parallel grid dimensions:"
f" {dimension_semantics=}"
)

# Try to find a divisible dimension to partition the grid on
divisible_dimensions = {
i for i in parallel_dimensions
if isinstance(grid[i], int) and grid[i] % num_cores == 0
}
if not divisible_dimensions:
# TODO(sharadmv): enable uneven grid partitioning
raise NotImplementedError(
f"Uneven partitioning of grid not supported: {grid=}, {num_cores=}"
)
first_divisible_dimension, *_ = [
i for i in range(len(dimension_semantics)) if i in divisible_dimensions
]
partitioned_dim_size = grid[first_divisible_dimension] // num_cores
partitioned_dim_offset = pl.program_id(core_axis) * partitioned_dim_size
new_grid = jax_util.tuple_update(
grid, first_divisible_dimension, partitioned_dim_size
)
offsets = jax_util.tuple_update(
(0,) * len(grid), first_divisible_dimension, partitioned_dim_offset
)
return new_grid, offsets


def emit_pipeline(
body,
*,
grid,
grid: tuple[int | jax.Array, ...],
in_specs=None,
out_specs=None,
should_accumulate_out=False,
core_axis: int | None = None,
dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None
):
"""Creates a function to emit a manual pallas pipeline.
Expand All @@ -653,7 +729,13 @@ def emit_pipeline(
out_specs: output pallas block specs
should_accumulate_out: booleans to indicate which outputs should be treated
as accumulators.
core_axis: optional int, indicates whether or not to partition the grid
along the core axis.
dimension_semantics: optional tuple of GridDimensionSemantics (e.g. PARALLEL
or ARBITRARY).
"""
grid, grid_offsets = _partition_grid(grid, core_axis, dimension_semantics)

num_steps = _grid_size(grid)
if not isinstance(in_specs, (list, tuple)):
in_specs = (in_specs,)
Expand Down Expand Up @@ -737,6 +819,7 @@ def loop_body(step, _):
scheduler = Scheduler(
step,
grid,
grid_offsets=grid_offsets,
first_cycle=first_cycle,
last_cycle=last_cycle,
init_accumulators=init_accumulators)
Expand Down
2 changes: 2 additions & 0 deletions jax/experimental/pallas/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
from jax._src.pallas.mosaic.pipeline import ARBITRARY
from jax._src.pallas.mosaic.pipeline import PARALLEL
from jax._src.pallas.mosaic.primitives import async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy
from jax._src.pallas.mosaic.primitives import bitcast
Expand Down
87 changes: 86 additions & 1 deletion tests/pallas/pallas_pipeline_tpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def emit_pipeline(should_accumulate_out):
np.testing.assert_allclose(z, jnp.dot(x, y) + jnp.dot(x, y))


class PallasCallColectivePipelineTest(parameterized.TestCase):
class PallasCallCollectivePipelineTest(parameterized.TestCase):

def setUp(self):
if jax.device_count() < 2:
Expand Down Expand Up @@ -1263,5 +1263,90 @@ def reference(x, y):
)


class PallasCallMegacoreTest(parameterized.TestCase):
  # Runs `pltpu.emit_pipeline` with `core_axis=0` inside a `pallas_call` whose
  # grid has one entry per core, and compares against a plain JAX reference.

  def setUp(self):
    if not jtu.is_device_tpu_at_least(4):
      self.skipTest('Only works with TPU v4')

    super().setUp()

  def test_megacore_mul(self):
    x = jax.random.uniform(jax.random.key(0), (512, 512))

    def make_tile_spec():
      # One 128x128 tile per (i, j) grid point.
      return pl.BlockSpec(lambda i, j: (i, j), (128, 128))

    def matmul_pipeline(x_ref, y_ref):
      y_ref[...] = x_ref[...] * 2

    def matmul_kernel(x_ref, y_ref):
      # Only the second grid dimension is marked PARALLEL.
      pipeline = pltpu.emit_pipeline(
          matmul_pipeline,
          grid=(4, 4),
          in_specs=[make_tile_spec()],
          out_specs=make_tile_spec(),
          core_axis=0,
          dimension_semantics=(pltpu.ARBITRARY, pltpu.PARALLEL),
      )
      pipeline(x_ref, y_ref)

    num_cores = jax.devices()[0].num_cores
    mosaic_params = dict(mosaic=dict(dimension_semantics=('parallel',)))
    func = pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32),
        in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)],
        out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
        grid=(num_cores,),
        compiler_params=mosaic_params,
    )
    np.testing.assert_allclose(func(x), x * 2)

  @parameterized.parameters(
      (1024, 1024, 1024, 256, 512, 256),
      (768, 1024, 1024, 256, 512, 256),
      (1024, 1024, 768, 256, 512, 256),
  )
  def test_megacore_matmul(self, m, k, n, bm, bk, bn):
    lhs_key, rhs_key = jax.random.split(jax.random.key(42))
    x = jax.random.uniform(lhs_key, (m, k))
    y = jax.random.uniform(rhs_key, (k, n))

    def matmul_pipeline(x_ref, y_ref, z_ref):
      # Zero the output tile on the first step of the reduction dimension.
      @pl.when(pl.program_id(2) == 0)
      def _():
        z_ref[...] = jnp.zeros_like(z_ref)

      z_ref[...] += x_ref[...] @ y_ref[...]

    def matmul_kernel(x_ref, y_ref, z_ref, *, bm, bk, bn):
      rows, depth = x_ref.shape
      _, cols = y_ref.shape
      lhs_spec = pl.BlockSpec(lambda i, j, kk: (i, kk), (bm, bk))
      rhs_spec = pl.BlockSpec(lambda i, j, kk: (kk, j), (bk, bn))
      out_spec = pl.BlockSpec(lambda i, j, kk: (i, j), (bm, bn))
      pipeline = pltpu.emit_pipeline(
          matmul_pipeline,
          grid=(rows // bm, cols // bn, depth // bk),
          in_specs=[lhs_spec, rhs_spec],
          out_specs=out_spec,
          core_axis=0,
          dimension_semantics=(
              pltpu.PARALLEL, pltpu.PARALLEL, pltpu.ARBITRARY
          ),
      )
      pipeline(x_ref, y_ref, z_ref)

    num_cores = jax.devices()[0].num_cores
    mosaic_params = dict(mosaic=dict(dimension_semantics=('parallel',)))
    func = pl.pallas_call(
        functools.partial(matmul_kernel, bm=bm, bk=bk, bn=bn),
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
        in_specs=[
            pl.BlockSpec(memory_space=pltpu.ANY),
            pl.BlockSpec(memory_space=pltpu.ANY),
        ],
        out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
        grid=(num_cores,),
        compiler_params=mosaic_params,
    )
    np.testing.assert_allclose(func(x, y), x @ y, atol=7e-5)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 701c63e

Please sign in to comment.