Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve typing of jax.jit #23720

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _update_debug_special_thread_local(_):


def jit(
fun: Callable,
fun: Callable[stages._P, stages._OutT],
in_shardings=sharding_impls.UNSPECIFIED,
out_shardings=sharding_impls.UNSPECIFIED,
static_argnums: int | Sequence[int] | None = None,
Expand All @@ -151,7 +151,7 @@ def jit(
backend: str | None = None,
inline: bool = False,
abstracted_axes: Any | None = None,
) -> pjit.JitWrapped:
) -> pjit.JitWrapped[stages._P, stages._OutT]:
"""Sets up ``fun`` for just-in-time compilation with XLA.

Args:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,7 +1649,7 @@ def _lu_solve(lu: Array, permutation: Array, b: Array, trans: int) -> Array:
def lu_solve(lu: ArrayLike, permutation: ArrayLike, b: ArrayLike,
trans: int = 0) -> Array:
"""LU solve with broadcasting."""
return _lu_solve(lu, permutation, b, trans)
return _lu_solve(lu, permutation, b, trans) # type: ignore[arg-type]


# QR decomposition
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,15 @@ def _cumprod(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None
Refer to :func:`jax.numpy.cumprod` for the full documentation.
"""
return reductions.cumprod(self, axis=axis, dtype=dtype, out=out)
return reductions.cumprod(self, axis=axis, dtype=dtype, out=out) # type: ignore[arg-type]

def _cumsum(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None,
out: None = None) -> Array:
"""Return the cumulative sum of the array.
Refer to :func:`jax.numpy.cumsum` for the full documentation.
"""
return reductions.cumsum(self, axis=axis, dtype=dtype, out=out)
return reductions.cumsum(self, axis=axis, dtype=dtype, out=out) # type: ignore[arg-type]

def _diagonal(self: Array, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array:
"""Return the specified diagonal from the array.
Expand Down
10 changes: 5 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
[6, 5]]], dtype=int32)
"""
util.check_arraylike("flip", m)
return _flip(asarray(m), reductions._ensure_optional_axes(axis))
return _flip(asarray(m), reductions._ensure_optional_axes(axis)) # type: ignore[arg-type]

@partial(jit, static_argnames=('axis',))
def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
Expand Down Expand Up @@ -1982,7 +1982,7 @@ def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
Array([0, 1, 2], dtype=int32)
"""
util.check_arraylike("squeeze", a)
return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None)
return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None) # type: ignore[arg-type]

@partial(jit, static_argnames=('axis',), inline=True)
def _squeeze(a: Array, axis: tuple[int, ...]) -> Array:
Expand Down Expand Up @@ -7259,7 +7259,7 @@ def delete(
obj = asarray(obj).ravel()
obj = clip(where(obj < 0, obj + a.shape[axis], obj), 0, a.shape[axis])
obj = sort(obj)
obj -= arange(len(obj)) # type: ignore[arg-type,operator]
obj -= arange(len(obj))
i = arange(a.shape[axis] - obj.size)
i += (i[None, :] >= obj[:, None]).sum(0)
return a[(slice(None),) * axis + (i,)]
Expand Down Expand Up @@ -11034,8 +11034,8 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False,
kwds: dict[str, str] = {} if method is None else {'method': method}
return where(
bins_arr[-1] >= bins_arr[0],
searchsorted(bins_arr, x, side=side, **kwds),
bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds)
searchsorted(bins_arr, x, side=side, **kwds), # type: ignore[arg-type]
bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds) # type: ignore[arg-type]
)


Expand Down
2 changes: 1 addition & 1 deletion jax/_src/numpy/ufunc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> An
if where is not None:
raise NotImplementedError(f"where argument of {self}")
call = self.__static_props['call'] or self._call_vectorized
return call(*args)
return call(*args) # type: ignore[arg-type]

@partial(jax.jit, static_argnames=['self'])
def _call_vectorized(self, *args):
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ def ax_leaf(l):
return broadcast_prefix(abstracted_axes, args, ax_leaf)


class JitWrapped(stages.Wrapped):
class JitWrapped(stages.Wrapped[stages._P, stages._OutT]):

def eval_shape(self, *args, **kwargs):
"""See ``jax.eval_shape``."""
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/scipy/cluster/vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ def vq(obs: ArrayLike, code_book: ArrayLike, check_finite: bool = True) -> tuple
raise ValueError("ndim different than 1 or 2 are not supported")
dist = vmap(lambda ob: jnp.linalg.norm(ob[None] - cb_arr, axis=-1))(obs_arr)
code = jnp.argmin(dist, axis=-1)
dist_min = vmap(operator.getitem)(dist, code)
dist_min = vmap(operator.getitem)(dist, code) # type: ignore[call-overload]
return code, dist_min
4 changes: 2 additions & 2 deletions jax/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[False]) -> Ar
@overload
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]: ...

@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
@partial(jit, static_argnames=('full_matrices', 'compute_uv')) # type: ignore[misc]
def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]:
a, = promote_dtypes_inexact(jnp.asarray(a))
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
Expand Down Expand Up @@ -545,7 +545,7 @@ def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]:
if output not in ('real', 'complex'):
raise ValueError(
f"Expected 'output' to be either 'real' or 'complex', got {output=}.")
return _schur(a, output)
return _schur(a, output) # type: ignore[arg-type]


def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,7 +1805,7 @@ def sph_harm(m: Array,
int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
'be statically specified to use `sph_harm` within JAX transformations.')

return _sph_harm(m, n, theta, phi, n_max)
return _sph_harm(m, n, theta, phi, n_max) # type: ignore[arg-type]


# exponential integrals
Expand Down
12 changes: 7 additions & 5 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import functools
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
from typing import Any, NamedTuple, ParamSpec, Protocol, TypeVar, Union, runtime_checkable

import jax

Expand Down Expand Up @@ -751,9 +751,11 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
raise ValueError(msg) from None
return Lowered(lowering, self.args_info, self._out_tree)

_P = ParamSpec("_P")
_OutT = TypeVar("_OutT", covariant=True) # pytype: disable=not-supported-yet

@runtime_checkable
class Wrapped(Protocol):
class Wrapped(Protocol[_P, _OutT]):
"""A function ready to be traced, lowered, and compiled.
This protocol reflects the output of functions such as
Expand All @@ -762,11 +764,11 @@ class Wrapped(Protocol):
to compilation, and the result compiled prior to execution.
"""

def __call__(self, *args, **kwargs):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _OutT:
"""Executes the wrapped function, lowering and compiling as needed."""
raise NotImplementedError

def trace(self, *args, **kwargs) -> Traced:
def trace(self, *args: _P.args, **kwargs: _P.kwargs) -> Traced:
"""Trace this function explicitly for the given arguments.
A traced function is staged out of Python and translated to a jaxpr. It is
Expand All @@ -777,7 +779,7 @@ def trace(self, *args, **kwargs) -> Traced:
"""
raise NotImplementedError

def lower(self, *args, **kwargs) -> Lowered:
def lower(self, *args: _P.args, **kwargs: _P.kwargs) -> Lowered:
"""Lower this function explicitly for the given arguments.
A lowered function is staged out of Python and translated to a
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/third_party/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]:
c_large = c_inf
s_large = s_inf
else:
c_large = 0.5 + 1 / (jnp.pi * x) * sinpi
s_large = 0.5 - 1 / (jnp.pi * x) * cospi
c_large = 0.5 + 1 / (jnp.pi * x) * sinpi # type: ignore[assignment]
s_large = 0.5 - 1 / (jnp.pi * x) * cospi # type: ignore[assignment]

# Other x values
t = jnp.pi * x2
Expand Down