Skip to content

Commit

Permalink
Merge branch 'main' into deven/update_minimum_cpu_core_request
Browse files Browse the repository at this point in the history
  • Loading branch information
devennavani committed Aug 5, 2024
2 parents 3a44d32 + 782f90f commit 60811d9
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 13 deletions.
2 changes: 1 addition & 1 deletion modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ async def handle_input_exception(self, input_id, started_at: float) -> AsyncGene
# just skip creating any output for this input and keep going with the next instead
# it should have been marked as cancelled already in the backend at this point so it
# won't be retried
logger.warning(f"The current input ({input_id=}) was cancelled by a user request")
logger.warning(f"Received a cancellation signal while processing input {input_id}")
await self.complete_call(started_at)
return
except BaseException as exc:
Expand Down
21 changes: 14 additions & 7 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from enum import Enum
from pathlib import Path, PurePosixPath
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Type
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Tuple, Type

from grpclib import GRPCError
from grpclib.exceptions import StreamTerminatedError
Expand All @@ -28,6 +28,12 @@ class FunctionInfoType(Enum):
NOTEBOOK = "notebook"


CLASS_PARAM_TYPE_MAP: Dict[Type, Tuple["api_pb2.ParameterType.ValueType", str]] = {
str: (api_pb2.PARAM_TYPE_STRING, "string_default"),
int: (api_pb2.PARAM_TYPE_INT, "int_default"),
}


class LocalFunctionError(InvalidError):
"""Raised if a function declared in a non-global scope is used in an impermissible way"""

Expand Down Expand Up @@ -238,13 +244,14 @@ def class_parameter_info(self) -> api_pb2.ClassParameterInfo:
modal_parameters: List[api_pb2.ClassParameterSpec] = []
signature = inspect.signature(self.user_cls)
for param in signature.parameters.values():
if param.annotation == str:
param_type = api_pb2.PARAM_TYPE_STRING
elif param.annotation == int:
param_type = api_pb2.PARAM_TYPE_INT
else:
has_default = param.default is not param.empty
if param.annotation not in CLASS_PARAM_TYPE_MAP:
raise InvalidError("Strict class parameters need to be explicitly annotated as str or int")
modal_parameters.append(api_pb2.ClassParameterSpec(name=param.name, type=param_type))
param_type, default_field = CLASS_PARAM_TYPE_MAP[param.annotation]
class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default, type=param_type)
if has_default:
setattr(class_param_spec, default_field, param.default)
modal_parameters.append(class_param_spec)

return api_pb2.ClassParameterInfo(
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO, schema=modal_parameters
Expand Down
5 changes: 5 additions & 0 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,11 @@ message ClassParameterSet {
message ClassParameterSpec {
string name = 1;
ParameterType type = 2;
bool has_default = 3;
oneof default_oneof {
string string_default = 4;
int64 int_default = 5;
}
}

message ClassParameterValue {
Expand Down
2 changes: 1 addition & 1 deletion modal_version/_version_generated.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright Modal Labs 2024

# Note: Reset this value to -1 whenever you make a minor `0.X` release of the client.
build_number = 4 # git: 497c8e6
build_number = 5 # git: 34b31cd
6 changes: 3 additions & 3 deletions test/cls_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def test_cls_strict_parameters_added_to_definition(client, servicer, monkeypatch

@strict_param_cls_app.cls(serialized=True)
class StrictParamCls:
def __init__(self, x: str, y: int):
def __init__(self, x: str, y: int = 20):
pass

deploy_app(strict_param_cls_app, "my-cls-app", client=client)
Expand All @@ -808,8 +808,8 @@ def __init__(self, x: str, y: int):
assert definition.class_parameter_info == api_pb2.ClassParameterInfo(
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO,
schema=[
api_pb2.ClassParameterSpec(name="x", type=api_pb2.PARAM_TYPE_STRING),
api_pb2.ClassParameterSpec(name="y", type=api_pb2.PARAM_TYPE_INT),
api_pb2.ClassParameterSpec(name="x", type=api_pb2.PARAM_TYPE_STRING, has_default=False),
api_pb2.ClassParameterSpec(name="y", type=api_pb2.PARAM_TYPE_INT, has_default=True, int_default=20),
],
)

Expand Down
2 changes: 1 addition & 1 deletion test/container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ def test_cancellation_aborts_current_input_on_match(
api_pb2.ContainerHeartbeatResponse(cancel_input_event=api_pb2.CancelInputEvent(input_ids=cancelled_input_ids))
)
stdout, stderr = container_process.communicate()
assert stderr.decode().count("was cancelled by a user request") == live_cancellations
assert stderr.decode().count("Received a cancellation signal") == live_cancellations
assert "Traceback" not in stderr.decode()
assert container_process.returncode == 0 # wait for container to exit
duration = time.monotonic() - t0 # time from heartbeat to container exit
Expand Down

0 comments on commit 60811d9

Please sign in to comment.