Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MOD-3431] Add default value field to ClassParameterSpec #2072

Merged
merged 8 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from enum import Enum
from pathlib import Path, PurePosixPath
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Type
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Tuple, Type

from grpclib import GRPCError
from grpclib.exceptions import StreamTerminatedError
Expand All @@ -28,6 +28,12 @@ class FunctionInfoType(Enum):
NOTEBOOK = "notebook"


CLASS_PARAM_TYPE_MAP: Dict[Type, Tuple["api_pb2.ParameterType.ValueType", str]] = {
str: (api_pb2.PARAM_TYPE_STRING, "string_default"),
int: (api_pb2.PARAM_TYPE_INT, "int_default"),
}


class LocalFunctionError(InvalidError):
"""Raised if a function declared in a non-global scope is used in an impermissible way"""

Expand Down Expand Up @@ -238,13 +244,14 @@ def class_parameter_info(self) -> api_pb2.ClassParameterInfo:
modal_parameters: List[api_pb2.ClassParameterSpec] = []
signature = inspect.signature(self.user_cls)
for param in signature.parameters.values():
if param.annotation == str:
param_type = api_pb2.PARAM_TYPE_STRING
elif param.annotation == int:
param_type = api_pb2.PARAM_TYPE_INT
else:
has_default = param.default is not param.empty
if param.annotation not in CLASS_PARAM_TYPE_MAP:
raise InvalidError("Strict class parameters need to be explicitly annotated as str or int")
modal_parameters.append(api_pb2.ClassParameterSpec(name=param.name, type=param_type))
param_type, default_field = CLASS_PARAM_TYPE_MAP[param.annotation]
class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default, type=param_type)
if has_default:
setattr(class_param_spec, default_field, param.default)
modal_parameters.append(class_param_spec)

return api_pb2.ClassParameterInfo(
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO, schema=modal_parameters
Expand Down
5 changes: 5 additions & 0 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,11 @@ message ClassParameterSet {
message ClassParameterSpec {
string name = 1;
ParameterType type = 2;
bool has_default = 3;
oneof default_oneof {
string string_default = 4;
int64 int_default = 5;
}
}

message ClassParameterValue {
Expand Down
6 changes: 3 additions & 3 deletions test/cls_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def test_cls_strict_parameters_added_to_definition(client, servicer, monkeypatch

@strict_param_cls_app.cls(serialized=True)
class StrictParamCls:
def __init__(self, x: str, y: int):
def __init__(self, x: str, y: int = 20):
pass

deploy_app(strict_param_cls_app, "my-cls-app", client=client)
Expand All @@ -808,8 +808,8 @@ def __init__(self, x: str, y: int):
assert definition.class_parameter_info == api_pb2.ClassParameterInfo(
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO,
schema=[
api_pb2.ClassParameterSpec(name="x", type=api_pb2.PARAM_TYPE_STRING),
api_pb2.ClassParameterSpec(name="y", type=api_pb2.PARAM_TYPE_INT),
api_pb2.ClassParameterSpec(name="x", type=api_pb2.PARAM_TYPE_STRING, has_default=False),
api_pb2.ClassParameterSpec(name="y", type=api_pb2.PARAM_TYPE_INT, has_default=True, int_default=20),
],
)

Expand Down
Loading