From 782f90f0b132ce28db9e30cf4a619f5c51dc2b71 Mon Sep 17 00:00:00 2001 From: Deven Navani Date: Mon, 5 Aug 2024 12:17:02 -0400 Subject: [PATCH] [MOD-3431] Add default value field to ClassParameterSpec (#2072) --- modal/_utils/function_utils.py | 21 ++++++++++++++------- modal_proto/api.proto | 5 +++++ test/cls_test.py | 6 +++--- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index 8752e9a40..5046ddc94 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -4,7 +4,7 @@ import os from enum import Enum from pathlib import Path, PurePosixPath -from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Type +from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Tuple, Type from grpclib import GRPCError from grpclib.exceptions import StreamTerminatedError @@ -28,6 +28,12 @@ class FunctionInfoType(Enum): NOTEBOOK = "notebook" +CLASS_PARAM_TYPE_MAP: Dict[Type, Tuple["api_pb2.ParameterType.ValueType", str]] = { + str: (api_pb2.PARAM_TYPE_STRING, "string_default"), + int: (api_pb2.PARAM_TYPE_INT, "int_default"), +} + + class LocalFunctionError(InvalidError): """Raised if a function declared in a non-global scope is used in an impermissible way""" @@ -238,13 +244,14 @@ def class_parameter_info(self) -> api_pb2.ClassParameterInfo: modal_parameters: List[api_pb2.ClassParameterSpec] = [] signature = inspect.signature(self.user_cls) for param in signature.parameters.values(): - if param.annotation == str: - param_type = api_pb2.PARAM_TYPE_STRING - elif param.annotation == int: - param_type = api_pb2.PARAM_TYPE_INT - else: + has_default = param.default is not param.empty + if param.annotation not in CLASS_PARAM_TYPE_MAP: raise InvalidError("Strict class parameters need to be explicitly annotated as str or int") - modal_parameters.append(api_pb2.ClassParameterSpec(name=param.name, type=param_type)) + param_type, default_field = CLASS_PARAM_TYPE_MAP[param.annotation] + class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default, type=param_type) + if has_default: + setattr(class_param_spec, default_field, param.default) + modal_parameters.append(class_param_spec) return api_pb2.ClassParameterInfo( format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO, schema=modal_parameters diff --git a/modal_proto/api.proto b/modal_proto/api.proto index 4e65d3a8b..157556212 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -585,6 +585,11 @@ message ClassParameterSet { message ClassParameterSpec { string name = 1; ParameterType type = 2; + bool has_default = 3; + oneof default_oneof { + string string_default = 4; + int64 int_default = 5; + } } message ClassParameterValue { diff --git a/test/cls_test.py b/test/cls_test.py index 566b51b0d..35e40a1c9 100644 --- a/test/cls_test.py +++ b/test/cls_test.py @@ -797,7 +797,7 @@ def test_cls_strict_parameters_added_to_definition(client, servicer, monkeypatch @strict_param_cls_app.cls(serialized=True) class StrictParamCls: - def __init__(self, x: str, y: int): + def __init__(self, x: str, y: int = 20): pass deploy_app(strict_param_cls_app, "my-cls-app", client=client) @@ -808,8 +808,8 @@ def __init__(self, x: str, y: int): assert definition.class_parameter_info == api_pb2.ClassParameterInfo( format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO, schema=[ - api_pb2.ClassParameterSpec(name="x", type=api_pb2.PARAM_TYPE_STRING), - api_pb2.ClassParameterSpec(name="y", type=api_pb2.PARAM_TYPE_INT), + api_pb2.ClassParameterSpec(name="x", type=api_pb2.PARAM_TYPE_STRING, has_default=False), + api_pb2.ClassParameterSpec(name="y", type=api_pb2.PARAM_TYPE_INT, has_default=True, int_default=20), ], )