Skip to content

Commit

Permalink
Improve typing of jax.jit
Browse files Browse the repository at this point in the history
- Fix for jax-ml#23719

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Sep 18, 2024
1 parent e903369 commit 2035a15
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
8 changes: 5 additions & 3 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import inspect
import math
import typing
from typing import (Any, Literal, NamedTuple, TypeVar, overload,
from typing import (Any, Literal, NamedTuple, ParamSpec, TypeVar, overload,
cast)
import weakref

Expand Down Expand Up @@ -137,9 +137,11 @@ def _update_debug_special_thread_local(_):

float0 = dtypes.float0

_P = ParamSpec("_P")
_OutT = TypeVar("_OutT", covariant=True)

def jit(
fun: Callable,
fun: Callable[_P, _OutT],
in_shardings=sharding_impls.UNSPECIFIED,
out_shardings=sharding_impls.UNSPECIFIED,
static_argnums: int | Sequence[int] | None = None,
Expand All @@ -151,7 +153,7 @@ def jit(
backend: str | None = None,
inline: bool = False,
abstracted_axes: Any | None = None,
) -> pjit.JitWrapped:
) -> pjit.JitWrapped[_P, _OutT]:
"""Sets up ``fun`` for just-in-time compilation with XLA.
Args:
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import logging
import operator as op
import weakref
from typing import NamedTuple, Any, Union, cast
from typing import NamedTuple, Any, ParamSpec, TypeVar, Union, cast
import threading
import warnings

Expand Down Expand Up @@ -807,8 +807,10 @@ def ax_leaf(l):
isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
return broadcast_prefix(abstracted_axes, args, ax_leaf)

_P = ParamSpec("_P")
_OutT = TypeVar("_OutT", covariant=True)

class JitWrapped(stages.Wrapped):
class JitWrapped(stages.Wrapped[_P, _OutT]):

def eval_shape(self, *args, **kwargs):
"""See ``jax.eval_shape``."""
Expand Down
12 changes: 7 additions & 5 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import functools
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
from typing import Any, NamedTuple, ParamSpec, Protocol, TypeVar, Union, runtime_checkable

import jax

Expand Down Expand Up @@ -751,9 +751,11 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
raise ValueError(msg) from None
return Lowered(lowering, self.args_info, self._out_tree)

_P = ParamSpec("_P")
_OutT = TypeVar("_OutT", covariant=True)

@runtime_checkable
class Wrapped(Protocol):
class Wrapped(Protocol[_P, _OutT]):
"""A function ready to be traced, lowered, and compiled.
This protocol reflects the output of functions such as
Expand All @@ -762,11 +764,11 @@ class Wrapped(Protocol):
to compilation, and the result compiled prior to execution.
"""

def __call__(self, *args, **kwargs):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _OutT:
"""Executes the wrapped function, lowering and compiling as needed."""
raise NotImplementedError

def trace(self, *args, **kwargs) -> Traced:
def trace(self, *args: _P.args, **kwargs: _P.kwargs) -> Traced:
"""Trace this function explicitly for the given arguments.
A traced function is staged out of Python and translated to a jaxpr. It is
Expand All @@ -777,7 +779,7 @@ def trace(self, *args, **kwargs) -> Traced:
"""
raise NotImplementedError

def lower(self, *args, **kwargs) -> Lowered:
def lower(self, *args: _P.args, **kwargs: _P.kwargs) -> Lowered:
"""Lower this function explicitly for the given arguments.
A lowered function is staged out of Python and translated to a
Expand Down

0 comments on commit 2035a15

Please sign in to comment.