Skip to content

Commit

Permalink
Pallas GPU no longer falls back to lax.pow for integer powers
Browse files Browse the repository at this point in the history
Instead the lowering computes the power in a loop by squaring, similarly
to how we do it in the StableHLO lowering.

Fixes #21928.

PiperOrigin-RevId: 644313113
  • Loading branch information
superbobry authored and jax authors committed Jun 18, 2024
1 parent 5bfd6af commit dfcfb36
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 9 deletions.
35 changes: 26 additions & 9 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,14 +1073,32 @@ def _broadcast_to_rule(ctx: LoweringRuleContext, x, shape: Sequence[int]):
return _bcast_to(_ensure_ir_value(x, x_aval), shape)


def _integer_pow(a, *, y):
if y == 2:
return a * a
if y == 3:
return a * a * a
if y == -2:
return 1.0 / (a * a)
return jax.lax.pow(a, y)
@register_lowering(lax.integer_pow_p)
def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int):
if y == 0:
return _full(x.type, 1)

is_reciprocal = y < 0
if is_reciprocal:
y = -y

acc = None
while y > 0:
y, mod = divmod(y, 2)
if mod:
acc = x if acc is None else _mul(acc, x)
if y > 0:
x = _mul(x, x)
assert acc is not None

[x_aval] = ctx.avals_in
[out_aval] = ctx.avals_out
acc = _cast(acc, x_aval.dtype, out_aval.dtype)
if is_reciprocal:
signed = jnp.issubdtype(out_aval.dtype, jnp.signedinteger)
return _truediv(_full(acc.type, 1), acc, signed=signed)
else:
return acc


def lower_fun(
Expand All @@ -1100,7 +1118,6 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params):

_JAX_FN_MAPPING = {
lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max),
lax.integer_pow_p: _integer_pow,
lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)),
}

Expand Down
11 changes: 11 additions & 0 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,17 @@ def kernel(x_ref, y_ref, o_ref):
y = jnp.array([1, 2, 3, 4]).astype(y_dtype)
np.testing.assert_allclose(kernel(x, y), lax.pow(x, y))

@parameterized.parameters(0, 1, 2, 3, 4, 5, -1, -2, -3)
def test_integer_pow(self, y):
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
)
def kernel(x_ref, o_ref):
o_ref[:] = lax.integer_pow(x_ref[...], y)

x = jnp.array([1, 2, 3, 4]).astype(jnp.float32) / 10
np.testing.assert_allclose(kernel(x), lax.integer_pow(x, y))

@parameterized.parameters("float32", "float64")
def test_nextafter(self, dtype):
if jtu.test_device_matches(["tpu"]) and dtype == "float64":
Expand Down

0 comments on commit dfcfb36

Please sign in to comment.