Skip to content

Commit

Permalink
[MOD-3431] Add default value field to ClassParameterSpec (#2072)
Browse files Browse the repository at this point in the history
  • Loading branch information
devennavani committed Aug 5, 2024
1 parent 0e0dac5 commit 782f90f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 10 deletions.
21 changes: 14 additions & 7 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from enum import Enum
from pathlib import Path, PurePosixPath
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Type
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Tuple, Type

from grpclib import GRPCError
from grpclib.exceptions import StreamTerminatedError
Expand All @@ -28,6 +28,12 @@ class FunctionInfoType(Enum):
NOTEBOOK = "notebook"


CLASS_PARAM_TYPE_MAP: Dict[Type, Tuple["api_pb2.ParameterType.ValueType", str]] = {
str: (api_pb2.PARAM_TYPE_STRING, "string_default"),
int: (api_pb2.PARAM_TYPE_INT, "int_default"),
}


class LocalFunctionError(InvalidError):
"""Raised if a function declared in a non-global scope is used in an impermissible way"""

Expand Down Expand Up @@ -238,13 +244,14 @@ def class_parameter_info(self) -> api_pb2.ClassParameterInfo:
modal_parameters: List[api_pb2.ClassParameterSpec] = []
signature = inspect.signature(self.user_cls)
for param in signature.parameters.values():
if param.annotation == str:
param_type = api_pb2.PARAM_TYPE_STRING
elif param.annotation == int:
param_type = api_pb2.PARAM_TYPE_INT
else:
has_default = param.default is not param.empty
if param.annotation not in CLASS_PARAM_TYPE_MAP:
raise InvalidError("Strict class parameters need to be explicitly annotated as str or int")
modal_parameters.append(api_pb2.ClassParameterSpec(name=param.name, type=param_type))
param_type, default_field = CLASS_PARAM_TYPE_MAP[param.annotation]
class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default, type=param_type)
if has_default:
setattr(class_param_spec, default_field, param.default)
modal_parameters.append(class_param_spec)

return api_pb2.ClassParameterInfo(
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO, schema=modal_parameters
Expand Down
5 changes: 5 additions & 0 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,11 @@ message ClassParameterSet {
message ClassParameterSpec {
string name = 1;
ParameterType type = 2;
bool has_default = 3;
oneof default_oneof {
string string_default = 4;
int64 int_default = 5;
}
}

message ClassParameterValue {
Expand Down
6 changes: 3 additions & 3 deletions test/cls_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def test_cls_strict_parameters_added_to_definition(client, servicer, monkeypatch

@strict_param_cls_app.cls(serialized=True)
class StrictParamCls:
def __init__(self, x: str, y: int):
def __init__(self, x: str, y: int = 20):
pass

deploy_app(strict_param_cls_app, "my-cls-app", client=client)
Expand All @@ -808,8 +808,8 @@ def __init__(self, x: str, y: int):
assert definition.class_parameter_info == api_pb2.ClassParameterInfo(
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO,
schema=[
api_pb2.ClassParameterSpec(name="x", type=api_pb2.PARAM_TYPE_STRING),
api_pb2.ClassParameterSpec(name="y", type=api_pb2.PARAM_TYPE_INT),
api_pb2.ClassParameterSpec(name="x", type=api_pb2.PARAM_TYPE_STRING, has_default=False),
api_pb2.ClassParameterSpec(name="y", type=api_pb2.PARAM_TYPE_INT, has_default=True, int_default=20),
],
)

Expand Down

0 comments on commit 782f90f

Please sign in to comment.