From ddee9b48ed5c1e711e702af8a0de1bcbd8c3dc53 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Thu, 1 Aug 2024 16:51:15 -0400 Subject: [PATCH] Small type improvements (#2055) --- modal/_resolver.py | 2 +- modal/app.py | 4 ++-- modal/client.py | 3 ++- modal/mount.py | 2 +- modal/runner.py | 2 +- modal/volume.py | 2 +- tasks.py | 9 +++++++++ test/resolver_test.py | 8 ++++---- 8 files changed, 21 insertions(+), 11 deletions(-) diff --git a/modal/_resolver.py b/modal/_resolver.py index 8fd69b3f8..b4a1bd94f 100644 --- a/modal/_resolver.py +++ b/modal/_resolver.py @@ -50,7 +50,7 @@ class Resolver: def __init__( self, - client=None, + client: _Client, *, environment_name: Optional[str] = None, app_id: Optional[str] = None, diff --git a/modal/app.py b/modal/app.py index c9058198a..0e1b70922 100644 --- a/modal/app.py +++ b/modal/app.py @@ -434,7 +434,7 @@ def registered_web_endpoints(self) -> List[str]: def local_entrypoint( self, _warn_parentheses_missing: Any = None, *, name: Optional[str] = None - ) -> Callable[[Callable[..., Any]], None]: + ) -> Callable[[Callable[..., Any]], _LocalEntrypoint]: """Decorate a function to be used as a CLI entrypoint for a Modal App. These functions can be used to define code that runs locally to set up the app, @@ -488,7 +488,7 @@ def main(foo: int, bar: str): if name is not None and not isinstance(name, str): raise InvalidError("Invalid value for `name`: Must be string.") - def wrapped(raw_f: Callable[..., Any]) -> None: + def wrapped(raw_f: Callable[..., Any]) -> _LocalEntrypoint: info = FunctionInfo(raw_f) tag = name if name is not None else raw_f.__qualname__ if tag in self._local_entrypoints: diff --git a/modal/client.py b/modal/client.py index 365f1eac9..b2c9737a0 100644 --- a/modal/client.py +++ b/modal/client.py @@ -102,8 +102,9 @@ def __init__( self._stub: Optional[api_grpc.ModalClientStub] = None @property - def stub(self) -> Optional[api_grpc.ModalClientStub]: + def stub(self) -> api_grpc.ModalClientStub: """mdmd:hidden""" + assert self._stub return self._stub @property diff --git a/modal/mount.py b/modal/mount.py index 4344e6b5a..ebf3dd967 100644 --- a/modal/mount.py +++ b/modal/mount.py @@ -598,7 +598,7 @@ async def _deploy( namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, environment_name: Optional[str] = None, client: Optional[_Client] = None, - ) -> "_Mount": + ) -> None: check_object_name(deployment_name, "Mount") self._deployment_name = deployment_name self._namespace = namespace diff --git a/modal/runner.py b/modal/runner.py index 116fa063d..39cb97da3 100644 --- a/modal/runner.py +++ b/modal/runner.py @@ -171,7 +171,7 @@ async def _disconnect( client: _Client, app_id: str, reason: "api_pb2.AppDisconnectReason.ValueType", - exc_str: Optional[str] = None, + exc_str: str = "", ) -> None: """Tell the server the client has disconnected for this app. Terminates all running tasks for ephemeral apps.""" diff --git a/modal/volume.py b/modal/volume.py index ddecd338d..86fdc49fa 100644 --- a/modal/volume.py +++ b/modal/volume.py @@ -144,7 +144,7 @@ def new(): Please use `Volume.from_name` (for persisted) or `Volume.ephemeral` (for ephemeral) volumes. """ - deprecation_error((2024, 3, 20), Volume.new.__doc__) + deprecation_error((2024, 3, 20), Volume.new.__doc__) # type: ignore @staticmethod def from_name( diff --git a/tasks.py b/tasks.py index 6a3856430..2119176fd 100644 --- a/tasks.py +++ b/tasks.py @@ -232,6 +232,15 @@ def type_stubs(ctx): # We only generate type stubs for modules that contain synchronicity wrapped types from synchronicity.synchronizer import SYNCHRONIZER_ATTR + stubs_to_remove = [] + for root, _, files in os.walk("modal"): + for file in files: + if file.endswith(".pyi"): + stubs_to_remove.append(os.path.abspath(os.path.join(root, file))) + for path in sorted(stubs_to_remove): + os.remove(path) + print(f"Removed {path}") + def find_modal_modules(root: str = "modal"): modules = [] path = importlib.import_module(root).__path__ diff --git a/test/resolver_test.py b/test/resolver_test.py index c458d65ed..6593e7c99 100644 --- a/test/resolver_test.py +++ b/test/resolver_test.py @@ -10,8 +10,8 @@ @pytest.mark.flaky(max_runs=2) @pytest.mark.asyncio -async def test_multi_resolve_sequential_loads_once(): - resolver = Resolver(None, environment_name="", app_id=None) +async def test_multi_resolve_sequential_loads_once(client): + resolver = Resolver(client, environment_name="", app_id=None) load_count = 0 @@ -35,8 +35,8 @@ async def _load(self: _DumbObject, resolver: Resolver, existing_object_id: Optio @pytest.mark.asyncio -async def test_multi_resolve_concurrent_loads_once(): - resolver = Resolver(None, environment_name="", app_id=None) +async def test_multi_resolve_concurrent_loads_once(client): + resolver = Resolver(client, environment_name="", app_id=None) load_count = 0