Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve typing of jax.jit #23720

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

lebrice
Copy link

@lebrice lebrice commented Sep 18, 2024

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 18, 2024

Thanks for the contribution! See #14688 for a past attempt at this, the things that went wrong, and some discussion of the considerations around a change like this. In particular, PyType was a blocker in the past, and we'll have to check whether that's still the case.

@lebrice
Copy link
Author

lebrice commented Sep 18, 2024

Thanks @jakevdp for the prompt response.
After this change, the MyPy type checker in the pre-commit hook now becomes aware of typing issues that were already present in the codebase. Would it normally be expected for this PR to also include fixes for these typing issues as well? Or should that be done in a separate PR?

@superbobry
Copy link
Collaborator

Yeah, newly discovered issues need to be fixed or silenced, because otherwise merging the PR will break our CI.

Could you given an example of the issue you discovered, please?

@lebrice
Copy link
Author

lebrice commented Sep 18, 2024

I just rebased the PR to only improve the typing of jax.jit. I'll make another PR to add type hints to jax.eval_shapes, JitWrapped.eval_shapes and JitWrapped.trace.

Here are the typing issues that the MyPy pre-commit hook now sees within the codebase, now that the signature of the callable isn't dropped by jax.jit:

jax/_src/numpy/ufunc_api.py:177: error: Argument 1 to "__call__" of "Wrapped" has incompatible type "*tuple[Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex, ...]"; expected "ufunc"  [arg-type]
jax/_src/numpy/array_methods.py:171: error: Argument "axis" to "__call__" of "Wrapped" has incompatible type "int | Sequence[int] | None"; expected "int | None"  [arg-type]
jax/_src/numpy/array_methods.py:179: error: Argument "axis" to "__call__" of "Wrapped" has incompatible type "int | Sequence[int] | None"; expected "int | None"  [arg-type]
jax/_src/lax/linalg.py:1665: error: Argument 1 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array"  [arg-type]
jax/_src/lax/linalg.py:1665: error: Argument 2 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array"  [arg-type]
jax/_src/lax/linalg.py:1665: error: Argument 3 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array"  [arg-type]
jax/_src/numpy/lax_numpy.py:1127: error: Argument 2 to "__call__" of "Wrapped" has incompatible type "int | Sequence[int] | None"; expected "int | tuple[int, ...] | None"  [arg-type]
jax/_src/numpy/lax_numpy.py:1985: error: Argument 2 to "__call__" of "Wrapped" has incompatible type "tuple[int, ...] | None"; expected "tuple[int, ...]"  [arg-type]
jax/_src/numpy/lax_numpy.py:7228: error: Unused "type: ignore[arg-type, operator]" comment  [unused-ignore]
jax/_src/third_party/scipy/special.py:275: error: Incompatible types in assignment (expression has type "Array", variable has type "float")  [assignment]
jax/_src/third_party/scipy/special.py:276: error: Incompatible types in assignment (expression has type "Array", variable has type "float")  [assignment]
jax/_src/scipy/linalg.py:222: error: Overloaded function implementation does not accept all possible arguments of signature 1  [misc]
jax/_src/scipy/linalg.py:222: error: Overloaded function implementation does not accept all possible arguments of signature 2  [misc]
jax/_src/scipy/linalg.py:222: error: Overloaded function implementation does not accept all possible arguments of signature 3  [misc]
jax/_src/scipy/linalg.py:548: error: Argument 1 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array"  [arg-type]
jax/_src/scipy/cluster/vq.py:73: error: No overload variant of "getitem" matches argument types "Array", "Array"  [call-overload]
jax/_src/scipy/cluster/vq.py:73: note: Possible overload variants:
jax/_src/scipy/cluster/vq.py:73: note:     def [_T] getitem(Sequence[_T], slice, /) -> Sequence[_T]
jax/_src/scipy/cluster/vq.py:73: note:     def [_K, _V] getitem(SupportsGetItem[_K, _V], _K, /) -> _V
jax/_src/scipy/special.py:1808: error: Argument 5 to "__call__" of "Wrapped" has incompatible type "int | None"; expected "int"  [arg-type]
Found 17 errors in 8 files (checked 511 source files)

What is the typical way that such typing errors are silenced in this project? Do you prefer # type: ignore comments or typing.cast ?
Also, should I leave a comment referencing a new issue for the typing errors that I silence?

@yashk2810
Copy link
Collaborator

I am pretty sure this will break a lot of other targets internally too making this change very difficult to land.

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 18, 2024

Yeah, #14688 was eventually blocked by the fact that pytype doesn't properly suppot ParamSpec. I'm not sure whether that's changed in the meantime.

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 18, 2024

Some of the errors reflect that the jit annotation in this PR is not correct. For example this one:

jax/_src/lax/linalg.py:1665: error: Argument 1 to "__call__" of "Wrapped" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | bool | int | float | complex"; expected "Array"  [arg-type]

It basically comes from something that looks like this:

@jit
def _lu_solve(x: Array):
  ...

def lu_solve(x: ArrayLike):
  return _lu_solve(x)  # <- type error, because ArrayLike is not Array

However, when you wrap a function with jit, all ArrayLike inputs are implicitly converted to Array before being passed to the wrapped function. So in some senses this annotation is correct, and the mypy error is a false-positive due to the new jit annotation being stricter than it needs to be.

What do you think?

@lebrice
Copy link
Author

lebrice commented Sep 18, 2024

re: @superbobry, @yashk2810 - I silenced the new typing errors that mypy raised in the pre-commit hook.

re: @jakevdp

However, when you wrap a function with jit, all ArrayLike inputs are implicitly converted to Array before being passed to the wrapped function. So in some senses this annotation is correct, and the mypy error is a false-positive due to the new jit annotation being stricter than it needs to be.

I agree. In my view, this only encourages internal jax source to be more explicit, by not depending on this implicit conversion from ArrayLike to Array.

@lebrice
Copy link
Author

lebrice commented Sep 18, 2024

re: @yashk2810

I am pretty sure this will break a lot of other targets internally too making this change very difficult to land.

Are you saying that this tiny little PR would also improve other downstream projects at Google? 🤩 😛

@yashk2810
Copy link
Collaborator

Haha, I wouldn't say improve but it will break a lot of stuff and I don't think we have the bandwidth to fix all those projects. Hence landing this is very hard IRL.

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 18, 2024

I think, setting aside caveats about how hard this might be to land, this is generally a change we want, and one we've been hoping to add for a long time. Initially we were blocked by the lack of ParamSpec in the type system, then we were blocked by the lack of support for ParamSpec in the mypy implementation, and then in the pytype implementation. If the pytype blocker is now fixed, we can do the work to land this (basically adding # ignore statements in any place that it breaks). But I think pytype may still be a blocker, as it was for #14688 six months ago.

The relevant issue is google/pytype#1471, which is still open.

@jakevdp jakevdp self-assigned this Sep 18, 2024
@lebrice
Copy link
Author

lebrice commented Sep 18, 2024

re: @yashk2810
I guess it's a matter of perspective. In my view, revealing typing errors / encouraging code to be more explicit is an improvement.

re: @yashk2810 @jakevdp
I'd be very happy to help and put in the time required in order to fix (or at the very least silence) any such typing errors that are revealed as a result of this PR in other projects. Do you by chance have a kind of list of these public-facing, jax-based projects that could potentially have their CI fail as a result of this change? If not, I can also try to gather such a list myself.

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 18, 2024

Just to be clear, the pytype blocker is not about revealing existing errors, it's about the fact that pytype fails loudly and completely when it sees covariant=True. If that hasn't changed, then I'm afraid we can't do much else here.

@lebrice
Copy link
Author

lebrice commented Sep 18, 2024

With respect to google/pytype#1471, would this change be easier to merge if the output TypeVar were not marked as covariant? 🤔

Edit: I'll double-check, but I think that the output var being invariant might cause other issues (which would then be due to the the annotation not being 100% correct)

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 18, 2024

Yes, if this didn't use covariant typevars, it would be easier to merge. But my understanding from #14688 was that covariant typevars are required in order to correctly annotate jit.

@superbobry
Copy link
Collaborator

@jakevdp pytype treats all type variables as covariant IIRC, so maybe we can just suppress the warning for that particular type var?

@lebrice
Copy link
Author

lebrice commented Sep 18, 2024

Re @superbobry @jakevdp :
I added some # pytype: disable=not-supported-yet over the typevar definitions. If my understanding of pytype is correct, it will now simply drop the covariant arg, and treat those as regular typevars.

@lebrice
Copy link
Author

lebrice commented Sep 18, 2024

@lebrice : I'd be very happy to help and put in the time required in order to fix (or at the very least silence) any such typing errors that are revealed as a result of this PR in other projects. Do you by chance have a kind of list of these public-facing, jax-based projects that could potentially have their CI fail as a result of this change? If not, I can also try to gather such a list myself.

I'll start with Flax, since that seems like the most obvious downstream project from my perspective.
I'm able to get their Pytype-related CI steps to run without error, at least locally.

@lebrice lebrice force-pushed the improve-jit-typing branch 2 times, most recently from e3faefb to be5fef9 Compare September 19, 2024 22:04
@jakevdp jakevdp added the pull ready Ready for copybara import and testing label Sep 20, 2024
@jakevdp
Copy link
Collaborator

jakevdp commented Sep 20, 2024

Pulling in to run internal pytype tests

jax/_src/api.py Outdated
@@ -151,7 +153,7 @@ def jit(
backend: str | None = None,
inline: bool = False,
abstracted_axes: Any | None = None,
) -> pjit.JitWrapped:
) -> pjit.JitWrapped[_P, _OutT]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm seeing an error here in pytype that's blocking any further testing:

/jax/_src/api.py:156: error: in <module>: class JitWrapped is not indexable [not-indexable]
  ('JitWrapped' does not subclass Generic)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JitWrapped does in fact inherit from Generic, so I suspect there's something deeper going on that's preventing pytype from seeing that within the package structure of bazel-based builds.

@lebrice lebrice force-pushed the improve-jit-typing branch 2 times, most recently from 5c11738 to ee069b3 Compare September 21, 2024 17:36
jax/_src/api.py Outdated
@@ -137,9 +137,11 @@ def _update_debug_special_thread_local(_):

float0 = dtypes.float0

_P = ParamSpec("_P")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reuse this from stages instead of redefining in 3 files.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Addressed in 32de804 (rebased to keep everything in a single commit)

jax/_src/pjit.py Outdated
@@ -807,8 +807,10 @@ def ax_leaf(l):
isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
return broadcast_prefix(abstracted_axes, args, ax_leaf)

_P = ParamSpec("_P")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reuse this from stages instead of redefining in 3 files.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Addressed in 32de804 (rebased to keep everything in a single commit)

@lebrice lebrice force-pushed the improve-jit-typing branch 2 times, most recently from 630c5f7 to 32de804 Compare September 21, 2024 17:43
- Fix for jax-ml#23719

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
@lebrice
Copy link
Author

lebrice commented Sep 21, 2024

I'm curious: If this ends up getting merged, will the Google-ML-Automation bot include my github username in the final commit? Or would someone on the inside need to add a Co-authored-by: Fabrice Normandin <normandf@mila.quebec> in the commit message?

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 21, 2024

I'm curious: If this ends up getting merged, will the Google-ML-Automation bot include my github username in the final commit? Or would someone on the inside need to add a Co-authored-by: Fabrice Normandin <normandf@mila.quebec> in the commit message?

If this is merged, your actual unmodified commit would be added to the JAX source tree.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants