Skip to content

Commit

Permalink
pl.debug_print no longer restricts values to be scalars
Browse files Browse the repository at this point in the history
This allows printing arrays on Triton and soon on Mosaic GPU.

PiperOrigin-RevId: 675935666
  • Loading branch information
superbobry authored and Google-ML-Automation committed Sep 18, 2024
1 parent 988ed2b commit b904599
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 15 deletions.
18 changes: 13 additions & 5 deletions docs/pallas/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c
Remember to align the itemized text with the first line of an item within a list.
-->

## Released with jax 0.4.34

* Changes

* {func}`jax.experimental.pallas.debug_print` no longer requires all arguments
to be scalars. The restrictions on the arguments are backend-specific:
Non-scalar arguments are currently only supported on GPU, when using Triton.

## Released with jax 0.4.33 (September 16, 2024)

## Released with jax 0.4.32 (September 11, 2024)

## Released with jax 0.4.32

* Changes
Expand All @@ -19,7 +31,7 @@ Remember to align the itemized text with the first line of an item within a list

* Deprecations

* New functionality:
* New functionality
* Improved error messages for mistakes in the signature of the index map functions,
to include the name and source location of the index map.

Expand Down Expand Up @@ -73,7 +85,3 @@ Remember to align the itemized text with the first line of an item within a list
* Added checkify support for {func}`jax.experimental.pallas.pallas_call` in
interpret mode ({jax-issue}`#21862`).
* Improved support for PRNG keys for TPU kernels ({jax-issue}`#21773`).




3 changes: 3 additions & 0 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2737,6 +2737,9 @@ def _delay_rule(ctx: LoweringRuleContext, nanos: int):
def _debug_print_rule(
ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool
):
if any(aval.shape for aval in ctx.avals_in):
raise NotImplementedError("Only scalar values are supported")

primitives.check_debug_print_format(fmt, *args)
if has_placeholders:
if not all(
Expand Down
5 changes: 3 additions & 2 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,9 @@ def _debug_print_lowering_rule(
fmt,
has_placeholders: bool,
):
del ctx
del has_placeholders
del has_placeholders # Unused.
if any(aval.shape for aval in ctx.avals_in):
raise NotImplementedError("Only scalar values are supported")
primitives.check_debug_print_format(fmt, *args)
mgpu.debug_print(fmt, *args)
return ()
Expand Down
12 changes: 5 additions & 7 deletions jax/_src/pallas/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ class PrintEffect(effects.Effect):


def debug_print(fmt: str, *args: jax.typing.ArrayLike):
"""Prints scalar values from inside a Pallas kernel.
"""Prints values from inside a Pallas kernel.
Args:
fmt: A format string to be included in the output. The restrictions on the
Expand All @@ -724,11 +724,11 @@ def debug_print(fmt: str, *args: jax.typing.ArrayLike):
(``{...}``), since it is always printed before any of the values.
* On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must
contain a placeholder for each value to be printed. Format specs and
conversions are not supported.
conversions are not supported. All values must be scalars.
* In TPU, if ``fmt`` contains placeholders, all values must be 32-bit
integers. If there are no placeholders, the values are printed after
the format string.
*args: The scalar values to print.
the format string. All values must be scalars.
*args: The values to print.
""" # fmt: skip
has_placeholders = False
if fmt:
Expand Down Expand Up @@ -771,9 +771,7 @@ def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool):

@debug_print_p.def_effectful_abstract_eval
def debug_print_abstract_eval(*avals: Any, fmt: str, has_placeholders: bool):
del fmt, has_placeholders
if any(aval.shape for aval in avals):
raise ValueError("Only scalar values are supported")
del avals, fmt, has_placeholders # Unused.
return [], {debug_print_effect}


Expand Down
9 changes: 8 additions & 1 deletion jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,14 @@ def debug_print_lowering_rule(
"pl.debug_print() does not support placeholders when lowering to Triton"
)

tt_dialect.print_(f" {fmt} ", hex=False, args=args)
tt_dialect.print_(
f" {fmt} ",
hex=False,
args=args,
is_signed=ir.DenseI32ArrayAttr.get([
jnp.issubdtype(aval.dtype, jnp.signedinteger) for aval in ctx.avals_in
]),
)
return ()


Expand Down

0 comments on commit b904599

Please sign in to comment.