Skip to content

Commit

Permalink
Improve typing of jax.jit
Browse files Browse the repository at this point in the history
- Fix for #23719

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Sep 21, 2024
1 parent a2b3919 commit 630c5f7
Show file tree
Hide file tree
Showing 11 changed files with 31 additions and 25 deletions.
8 changes: 5 additions & 3 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import inspect
import math
import typing
from typing import (Any, Literal, NamedTuple, TypeVar, overload,
from typing import (Any, Literal, NamedTuple, ParamSpec, TypeVar, overload,
cast)
import weakref

Expand Down Expand Up @@ -137,9 +137,11 @@ def _update_debug_special_thread_local(_):

float0 = dtypes.float0

_P = ParamSpec("_P")
_OutT = TypeVar("_OutT", covariant=True) # pytype: disable=not-supported-yet

def jit(
fun: Callable,
fun: Callable[_P, _OutT],
in_shardings=sharding_impls.UNSPECIFIED,
out_shardings=sharding_impls.UNSPECIFIED,
static_argnums: int | Sequence[int] | None = None,
Expand All @@ -151,7 +153,7 @@ def jit(
backend: str | None = None,
inline: bool = False,
abstracted_axes: Any | None = None,
) -> pjit.JitWrapped:
) -> pjit.JitWrapped[_P, _OutT]:
"""Sets up ``fun`` for just-in-time compilation with XLA.
Args:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,7 +1649,7 @@ def _lu_solve(lu: Array, permutation: Array, b: Array, trans: int) -> Array:
def lu_solve(lu: ArrayLike, permutation: ArrayLike, b: ArrayLike,
trans: int = 0) -> Array:
"""LU solve with broadcasting."""
return _lu_solve(lu, permutation, b, trans)
return _lu_solve(lu, permutation, b, trans) # type: ignore[arg-type]


# QR decomposition
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,15 @@ def _cumprod(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None
Refer to :func:`jax.numpy.cumprod` for the full documentation.
"""
return reductions.cumprod(self, axis=axis, dtype=dtype, out=out)
return reductions.cumprod(self, axis=axis, dtype=dtype, out=out) # type: ignore[arg-type]

def _cumsum(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None,
out: None = None) -> Array:
"""Return the cumulative sum of the array.
Refer to :func:`jax.numpy.cumsum` for the full documentation.
"""
return reductions.cumsum(self, axis=axis, dtype=dtype, out=out)
return reductions.cumsum(self, axis=axis, dtype=dtype, out=out) # type: ignore[arg-type]

def _diagonal(self: Array, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array:
"""Return the specified diagonal from the array.
Expand Down
10 changes: 5 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
[6, 5]]], dtype=int32)
"""
util.check_arraylike("flip", m)
return _flip(asarray(m), reductions._ensure_optional_axes(axis))
return _flip(asarray(m), reductions._ensure_optional_axes(axis)) # type: ignore[arg-type]

@partial(jit, static_argnames=('axis',))
def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
Expand Down Expand Up @@ -1982,7 +1982,7 @@ def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
Array([0, 1, 2], dtype=int32)
"""
util.check_arraylike("squeeze", a)
return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None)
return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None) # type: ignore[arg-type]

@partial(jit, static_argnames=('axis',), inline=True)
def _squeeze(a: Array, axis: tuple[int, ...]) -> Array:
Expand Down Expand Up @@ -7259,7 +7259,7 @@ def delete(
obj = asarray(obj).ravel()
obj = clip(where(obj < 0, obj + a.shape[axis], obj), 0, a.shape[axis])
obj = sort(obj)
obj -= arange(len(obj)) # type: ignore[arg-type,operator]
obj -= arange(len(obj))
i = arange(a.shape[axis] - obj.size)
i += (i[None, :] >= obj[:, None]).sum(0)
return a[(slice(None),) * axis + (i,)]
Expand Down Expand Up @@ -11034,8 +11034,8 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False,
kwds: dict[str, str] = {} if method is None else {'method': method}
return where(
bins_arr[-1] >= bins_arr[0],
searchsorted(bins_arr, x, side=side, **kwds),
bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds)
searchsorted(bins_arr, x, side=side, **kwds), # type: ignore[arg-type]
bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds) # type: ignore[arg-type]
)


Expand Down
2 changes: 1 addition & 1 deletion jax/_src/numpy/ufunc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> An
if where is not None:
raise NotImplementedError(f"where argument of {self}")
call = self.__static_props['call'] or self._call_vectorized
return call(*args)
return call(*args) # type: ignore[arg-type]

@partial(jax.jit, static_argnames=['self'])
def _call_vectorized(self, *args):
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import logging
import operator as op
import weakref
from typing import NamedTuple, Any, Union, cast
from typing import NamedTuple, Any, ParamSpec, TypeVar, Union, cast
import threading
import warnings

Expand Down Expand Up @@ -807,8 +807,10 @@ def ax_leaf(l):
isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
return broadcast_prefix(abstracted_axes, args, ax_leaf)

_P = ParamSpec("_P")
_OutT = TypeVar("_OutT", covariant=True) # pytype: disable=not-supported-yet

class JitWrapped(stages.Wrapped):
class JitWrapped(stages.Wrapped[_P, _OutT]):

def eval_shape(self, *args, **kwargs):
"""See ``jax.eval_shape``."""
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/scipy/cluster/vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ def vq(obs: ArrayLike, code_book: ArrayLike, check_finite: bool = True) -> tuple
raise ValueError("ndim different than 1 or 2 are not supported")
dist = vmap(lambda ob: jnp.linalg.norm(ob[None] - cb_arr, axis=-1))(obs_arr)
code = jnp.argmin(dist, axis=-1)
dist_min = vmap(operator.getitem)(dist, code)
dist_min = vmap(operator.getitem)(dist, code) # type: ignore[call-overload]
return code, dist_min
4 changes: 2 additions & 2 deletions jax/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[False]) -> Ar
@overload
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]: ...

@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
@partial(jit, static_argnames=('full_matrices', 'compute_uv')) # type: ignore[misc]
def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]:
a, = promote_dtypes_inexact(jnp.asarray(a))
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
Expand Down Expand Up @@ -545,7 +545,7 @@ def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]:
if output not in ('real', 'complex'):
raise ValueError(
f"Expected 'output' to be either 'real' or 'complex', got {output=}.")
return _schur(a, output)
return _schur(a, output) # type: ignore[arg-type]


def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,7 +1805,7 @@ def sph_harm(m: Array,
int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
'be statically specified to use `sph_harm` within JAX transformations.')

return _sph_harm(m, n, theta, phi, n_max)
return _sph_harm(m, n, theta, phi, n_max) # type: ignore[arg-type]


# exponential integrals
Expand Down
12 changes: 7 additions & 5 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import functools
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
from typing import Any, NamedTuple, ParamSpec, Protocol, TypeVar, Union, runtime_checkable

import jax

Expand Down Expand Up @@ -751,9 +751,11 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
raise ValueError(msg) from None
return Lowered(lowering, self.args_info, self._out_tree)

_P = ParamSpec("_P")
_OutT = TypeVar("_OutT", covariant=True) # pytype: disable=not-supported-yet

@runtime_checkable
class Wrapped(Protocol):
class Wrapped(Protocol[_P, _OutT]):
"""A function ready to be traced, lowered, and compiled.
This protocol reflects the output of functions such as
Expand All @@ -762,11 +764,11 @@ class Wrapped(Protocol):
to compilation, and the result compiled prior to execution.
"""

def __call__(self, *args, **kwargs):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _OutT:
"""Executes the wrapped function, lowering and compiling as needed."""
raise NotImplementedError

def trace(self, *args, **kwargs) -> Traced:
def trace(self, *args: _P.args, **kwargs: _P.kwargs) -> Traced:
"""Trace this function explicitly for the given arguments.
A traced function is staged out of Python and translated to a jaxpr. It is
Expand All @@ -777,7 +779,7 @@ def trace(self, *args, **kwargs) -> Traced:
"""
raise NotImplementedError

def lower(self, *args, **kwargs) -> Lowered:
def lower(self, *args: _P.args, **kwargs: _P.kwargs) -> Lowered:
"""Lower this function explicitly for the given arguments.
A lowered function is staged out of Python and translated to a
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/third_party/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]:
c_large = c_inf
s_large = s_inf
else:
c_large = 0.5 + 1 / (jnp.pi * x) * sinpi
s_large = 0.5 - 1 / (jnp.pi * x) * cospi
c_large = 0.5 + 1 / (jnp.pi * x) * sinpi # type: ignore[assignment]
s_large = 0.5 - 1 / (jnp.pi * x) * cospi # type: ignore[assignment]

# Other x values
t = jnp.pi * x2
Expand Down

0 comments on commit 630c5f7

Please sign in to comment.