Skip to content

Commit

Permalink
Estimate the amount of required scratch SMEM automatically in Pallas …
Browse files Browse the repository at this point in the history
…Mosaic GPU lowering

No estimation is done if `smem_scratch_bytes` was explicitly specified via
`compiler_params=`.

PiperOrigin-RevId: 672998660
  • Loading branch information
superbobry authored and jax authors committed Sep 10, 2024
1 parent 1b2ba9d commit 9fa0164
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
49 changes: 43 additions & 6 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,45 @@

partial = functools.partial

# Registry of SMEM usage estimators, keyed by the primitive they handle.
_smem_estimators = {}


def _regiter_smem_estimator(primitive: jax_core.Primitive):
  """Returns a decorator registering an SMEM estimator for ``primitive``.

  The decorated function is stored in ``_smem_estimators`` under
  ``primitive`` (replacing any previous entry for the same key) and is
  returned unchanged, so it stays callable under its own name.
  """

  def register(estimator):
    _smem_estimators[primitive] = estimator
    return estimator

  return register


def _estimate_smem_scratch_bytes(jaxpr: jax_core.Jaxpr) -> int:
  """Estimates the amount of SMEM scratch bytes required by the kernel.

  Only the equations directly in ``jaxpr.eqns`` are inspected. Each equation
  whose primitive has an entry in ``_smem_estimators`` is passed to that
  estimator as ``estimator(*input_avals, **eqn.params)``.

  Returns:
    The largest value reported by any single estimator, or 0 if no equation
    has a registered estimator.
  """
  peak = 0
  for eqn in jaxpr.eqns:
    # TODO(slebedev): Add support for other primitives, notably control flow.
    estimator = _smem_estimators.get(eqn.primitive)
    # Primitives without an estimator are assumed to be neutral wrt SMEM
    # usage, so they leave the running maximum untouched.
    if estimator is not None:
      in_avals = [invar.aval for invar in eqn.invars]
      peak = max(peak, estimator(*in_avals, **eqn.params))
  return peak


@_regiter_smem_estimator(primitives.run_scoped_p)
def _run_scoped_smem_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int:
  """Returns the total byte size of the inputs of a ``run_scoped`` jaxpr.

  For every variable in ``jaxpr.invars`` the size is the number of elements
  of ``var.aval.inner_aval`` multiplied by the item size of its dtype; the
  sizes are summed over all inputs. The equations inside ``jaxpr`` are not
  inspected.
  """
  del consts  # Unused.
  total_bytes = 0
  for var in jaxpr.invars:
    scoped_aval = var.aval.inner_aval
    total_bytes += math.prod(scoped_aval.shape) * scoped_aval.dtype.itemsize
  return total_bytes


@_regiter_smem_estimator(lax.reduce_sum_p)
def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
  """Returns the SMEM bytes estimated for a ``reduce_sum`` over axis 0.

  The estimate is four elements of the operand's dtype and does not depend
  on the operand's shape.

  Raises:
    NotImplementedError: If ``axes`` is anything other than ``(0,)``.
  """
  if axes == (0,):
    # NOTE(review): the factor 4 is presumably the number of partial results
    # staged in SMEM by the reduction lowering -- confirm against that rule.
    return 4 * x_aval.dtype.itemsize
  raise NotImplementedError("No support for axes other than 0 yet")


@dataclasses.dataclass
class ModuleContext:
Expand Down Expand Up @@ -358,13 +397,11 @@ def _(step, _):

launch_ctx.await_async_copy(0)

# TODO(b/354568888): Add a jaxpr traversal to calculate the precise
# amount of memory required.
# An explicit `smem_scratch_bytes` in `compiler_params` takes precedence; the
# jaxpr-based estimate is only used when the caller did not provide one.
# Note: there must be no trailing comma after `.get(...)` -- it would wrap the
# value in a 1-tuple, which is never `None`, so the estimate would be skipped
# and a tuple would end up in the scratch shape below.
smem_scratch_bytes = compiler_params.get("smem_scratch_bytes")
if smem_scratch_bytes is None:
  smem_scratch_bytes = _estimate_smem_scratch_bytes(jaxpr)
extra_smem_scratch = [
jax.ShapeDtypeStruct(
shape=[compiler_params.get("smem_scratch_bytes", 100000)],
dtype=np.int8,
)
jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8)
]
module, out_structs_smem, _ = mosaic_gpu._lower_as_gpu_kernel(
body,
Expand Down
1 change: 0 additions & 1 deletion tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def test_layer_norm(self, input_factor):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
compiler_params={"smem_scratch_bytes": 4 * 4},
)
def layer_norm(x_ref, o_ref):
x_mean = jnp.mean(x_ref[...])
Expand Down

0 comments on commit 9fa0164

Please sign in to comment.