Skip to content

Commit

Permalink
Small type improvements (#2055)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Aug 1, 2024
1 parent ab748d4 commit ddee9b4
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 11 deletions.
2 changes: 1 addition & 1 deletion modal/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Resolver:

def __init__(
self,
client=None,
client: _Client,
*,
environment_name: Optional[str] = None,
app_id: Optional[str] = None,
Expand Down
4 changes: 2 additions & 2 deletions modal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def registered_web_endpoints(self) -> List[str]:

def local_entrypoint(
self, _warn_parentheses_missing: Any = None, *, name: Optional[str] = None
) -> Callable[[Callable[..., Any]], None]:
) -> Callable[[Callable[..., Any]], _LocalEntrypoint]:
"""Decorate a function to be used as a CLI entrypoint for a Modal App.
These functions can be used to define code that runs locally to set up the app,
Expand Down Expand Up @@ -488,7 +488,7 @@ def main(foo: int, bar: str):
if name is not None and not isinstance(name, str):
raise InvalidError("Invalid value for `name`: Must be string.")

def wrapped(raw_f: Callable[..., Any]) -> None:
def wrapped(raw_f: Callable[..., Any]) -> _LocalEntrypoint:
info = FunctionInfo(raw_f)
tag = name if name is not None else raw_f.__qualname__
if tag in self._local_entrypoints:
Expand Down
3 changes: 2 additions & 1 deletion modal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ def __init__(
self._stub: Optional[api_grpc.ModalClientStub] = None

@property
def stub(self) -> Optional[api_grpc.ModalClientStub]:
def stub(self) -> api_grpc.ModalClientStub:
"""mdmd:hidden"""
assert self._stub
return self._stub

@property
Expand Down
2 changes: 1 addition & 1 deletion modal/mount.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ async def _deploy(
namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
environment_name: Optional[str] = None,
client: Optional[_Client] = None,
) -> "_Mount":
) -> None:
check_object_name(deployment_name, "Mount")
self._deployment_name = deployment_name
self._namespace = namespace
Expand Down
2 changes: 1 addition & 1 deletion modal/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ async def _disconnect(
client: _Client,
app_id: str,
reason: "api_pb2.AppDisconnectReason.ValueType",
exc_str: Optional[str] = None,
exc_str: str = "",
) -> None:
"""Tell the server the client has disconnected for this app. Terminates all running tasks
for ephemeral apps."""
Expand Down
2 changes: 1 addition & 1 deletion modal/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def new():
Please use `Volume.from_name` (for persisted) or `Volume.ephemeral` (for ephemeral) volumes.
"""
deprecation_error((2024, 3, 20), Volume.new.__doc__)
deprecation_error((2024, 3, 20), Volume.new.__doc__) # type: ignore

@staticmethod
def from_name(
Expand Down
9 changes: 9 additions & 0 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ def type_stubs(ctx):
# We only generate type stubs for modules that contain synchronicity wrapped types
from synchronicity.synchronizer import SYNCHRONIZER_ATTR

stubs_to_remove = []
for root, _, files in os.walk("modal"):
for file in files:
if file.endswith(".pyi"):
stubs_to_remove.append(os.path.abspath(os.path.join(root, file)))
for path in sorted(stubs_to_remove):
os.remove(path)
print(f"Removed {path}")

def find_modal_modules(root: str = "modal"):
modules = []
path = importlib.import_module(root).__path__
Expand Down
8 changes: 4 additions & 4 deletions test/resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

@pytest.mark.flaky(max_runs=2)
@pytest.mark.asyncio
async def test_multi_resolve_sequential_loads_once():
resolver = Resolver(None, environment_name="", app_id=None)
async def test_multi_resolve_sequential_loads_once(client):
resolver = Resolver(client, environment_name="", app_id=None)

load_count = 0

Expand All @@ -35,8 +35,8 @@ async def _load(self: _DumbObject, resolver: Resolver, existing_object_id: Optio


@pytest.mark.asyncio
async def test_multi_resolve_concurrent_loads_once():
resolver = Resolver(None, environment_name="", app_id=None)
async def test_multi_resolve_concurrent_loads_once(client):
resolver = Resolver(client, environment_name="", app_id=None)

load_count = 0

Expand Down

0 comments on commit ddee9b4

Please sign in to comment.