diff --git a/modal/__init__.py b/modal/__init__.py index 0b9a3539a..1c5021839 100644 --- a/modal/__init__.py +++ b/modal/__init__.py @@ -22,7 +22,7 @@ from .image import Image from .mount import Mount from .network_file_system import NetworkFileSystem - from .partial_function import asgi_app, build, enter, exit, method, web_endpoint, web_server, wsgi_app + from .partial_function import asgi_app, batch, build, enter, exit, method, web_endpoint, web_server, wsgi_app from .proxy import Proxy from .queue import Queue from .retries import Retries @@ -64,6 +64,7 @@ "Tunnel", "Volume", "asgi_app", + "batch", "build", "current_function_call_id", "current_input_id", diff --git a/modal/app.py b/modal/app.py index c9058198a..f6efca1e2 100644 --- a/modal/app.py +++ b/modal/app.py @@ -33,7 +33,12 @@ from .mount import _Mount from .network_file_system import _NetworkFileSystem from .object import _Object -from .partial_function import _find_callables_for_cls, _PartialFunction, _PartialFunctionFlags +from .partial_function import ( + _find_callables_for_cls, + _find_partial_methods_for_user_cls, + _PartialFunction, + _PartialFunctionFlags, +) from .proxy import _Proxy from .retries import Retries from .runner import _run_app @@ -589,13 +594,15 @@ def wrapped( ) if isinstance(f, _PartialFunction): - # typically for @function-wrapped @web_endpoint and @asgi_app + # typically for @function-wrapped @web_endpoint, @asgi_app, or @batch f.wrapped = True info = FunctionInfo(f.raw_f, serialized=serialized, name_override=name) raw_f = f.raw_f webhook_config = f.webhook_config is_generator = f.is_generator keep_warm = f.keep_warm or keep_warm + batch_max_size = f.batch_max_size + batch_linger_ms = f.batch_linger_ms if webhook_config and interactive: raise InvalidError("interactive=True is not supported with web endpoint functions") @@ -631,6 +638,8 @@ def f(self, x): info = FunctionInfo(f, serialized=serialized, name_override=name) webhook_config = None + batch_max_size = None + 
batch_linger_ms = None raw_f = f if info.function_name.endswith(".app"): @@ -668,6 +677,8 @@ def f(self, x): retries=retries, concurrency_limit=concurrency_limit, allow_concurrent_inputs=allow_concurrent_inputs, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, container_idle_timeout=container_idle_timeout, timeout=timeout, keep_warm=keep_warm, @@ -768,6 +779,21 @@ def wrapper(user_cls: CLS_T) -> _Cls: raise InvalidError("`region` and `_experimental_scheduler_placement` cannot be used together") scheduler_placement = SchedulerPlacement(region=region) + batch_functions = _find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.BATCH) + if batch_functions: + if len(batch_functions) > 1: + raise InvalidError(f"Modal class {user_cls.__name__} can only have one batch function.") + if len(_find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.FUNCTION)) > 1: + raise InvalidError( + f"Modal class {user_cls.__name__} with a modal batch function cannot have other modal methods." 
+ ) + batch_function = next(iter(batch_functions.values())) + batch_max_size = batch_function.batch_max_size + batch_linger_ms = batch_function.batch_linger_ms + else: + batch_max_size = None + batch_linger_ms = None + cls_func = _Function.from_args( info, app=self, @@ -785,6 +811,8 @@ def wrapper(user_cls: CLS_T) -> _Cls: retries=retries, concurrency_limit=concurrency_limit, allow_concurrent_inputs=allow_concurrent_inputs, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, container_idle_timeout=container_idle_timeout, timeout=timeout, cpu=cpu, diff --git a/modal/functions.py b/modal/functions.py index a01a44ada..89423cee7 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -344,6 +344,8 @@ async def _load(method_bound_function: "_Function", resolver: Resolver, existing is_method=True, use_function_id=class_service_function.object_id, use_method_name=method_name, + batch_max_size=partial_function.batch_max_size or 0, + batch_linger_ms=partial_function.batch_linger_ms or 0, ) assert resolver.app_id request = api_pb2.FunctionCreateRequest( @@ -484,6 +486,8 @@ def from_args( timeout: Optional[int] = None, concurrency_limit: Optional[int] = None, allow_concurrent_inputs: Optional[int] = None, + batch_max_size: Optional[int] = None, + batch_linger_ms: Optional[int] = None, container_idle_timeout: Optional[int] = None, cpu: Optional[float] = None, keep_warm: Optional[int] = None, # keep_warm=True is equivalent to keep_warm=1 @@ -652,6 +656,8 @@ def from_args( ) else: raise InvalidError("Webhooks cannot be generators") + if is_generator and batch_max_size: + raise InvalidError("Batch functions cannot return generators") if container_idle_timeout is not None and container_idle_timeout <= 0: raise InvalidError("`container_idle_timeout` must be > 0") @@ -808,6 +814,8 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona app_name=app_name, is_builder_function=is_builder_function, 
allow_concurrent_inputs=allow_concurrent_inputs or 0, + batch_max_size=batch_max_size or 0, + batch_linger_ms=batch_linger_ms or 0, worker_id=config.get("worker_id"), is_auto_snapshot=is_auto_snapshot, is_method=bool(info.user_cls) and not info.is_service_class(), diff --git a/modal/partial_function.py b/modal/partial_function.py index ff5d107b2..2b1f65ff4 100644 --- a/modal/partial_function.py +++ b/modal/partial_function.py @@ -20,6 +20,9 @@ from .exception import InvalidError, deprecation_error, deprecation_warning from .functions import _Function +MAX_BATCH_SIZE = 1000 +MAX_BATCH_LINGER_MS = 10 * 60 * 1000 # 10 minutes + class _PartialFunctionFlags(enum.IntFlag): FUNCTION: int = 1 @@ -27,6 +30,7 @@ class _PartialFunctionFlags(enum.IntFlag): ENTER_PRE_SNAPSHOT: int = 4 ENTER_POST_SNAPSHOT: int = 8 EXIT: int = 16 + BATCH: int = 32 @staticmethod def all() -> "_PartialFunctionFlags": @@ -34,13 +38,15 @@ def all() -> "_PartialFunctionFlags": class _PartialFunction: - """Intermediate function, produced by @method or @web_endpoint""" + """Intermediate function, produced by @method, @web_endpoint, or @batch""" raw_f: Callable[..., Any] flags: _PartialFunctionFlags webhook_config: Optional[api_pb2.WebhookConfig] is_generator: Optional[bool] keep_warm: Optional[int] + batch_max_size: Optional[int] + batch_linger_ms: Optional[int] def __init__( self, @@ -49,6 +55,8 @@ def __init__( webhook_config: Optional[api_pb2.WebhookConfig] = None, is_generator: Optional[bool] = None, keep_warm: Optional[int] = None, + batch_max_size: Optional[int] = None, + batch_linger_ms: Optional[int] = None, ): self.raw_f = raw_f self.flags = flags @@ -56,6 +64,8 @@ def __init__( self.is_generator = is_generator self.keep_warm = keep_warm self.wrapped = False # Make sure that this was converted into a FunctionHandle + self.batch_max_size = batch_max_size + self.batch_linger_ms = batch_linger_ms def __get__(self, obj, objtype=None) -> _Function: k = self.raw_f.__name__ @@ -93,6 +103,8 @@ def 
add_flags(self, flags) -> "_PartialFunction": flags=(self.flags | flags), webhook_config=self.webhook_config, keep_warm=self.keep_warm, + batch_max_size=self.batch_max_size, + batch_linger_ms=self.batch_linger_ms, ) @@ -190,12 +202,18 @@ def f(self): def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction: nonlocal is_generator - if isinstance(raw_f, _PartialFunction) and raw_f.webhook_config: + if isinstance(raw_f, _PartialFunction) and (raw_f.webhook_config): raw_f.wrapped = True # suppress later warning raise InvalidError( "Web endpoints on classes should not be wrapped by `@method`. " "Suggestion: remove the `@method` decorator." ) + if isinstance(raw_f, _PartialFunction) and (raw_f.batch_max_size is not None): + raw_f.wrapped = True # suppress later warning + raise InvalidError( + "Batch function on classes should not be wrapped by `@method`. " + "Suggestion: remove the `@method` decorator." + ) if is_generator is None: is_generator = inspect.isgeneratorfunction(raw_f) or inspect.isasyncgenfunction(raw_f) return _PartialFunction(raw_f, _PartialFunctionFlags.FUNCTION, is_generator=is_generator, keep_warm=keep_warm) @@ -554,6 +572,40 @@ def wrapper(f: ExitHandlerType) -> _PartialFunction: return wrapper +def _batch( + _warn_parentheses_missing=None, + *, + batch_max_size: int, + batch_linger_ms: int, +) -> Callable[[Callable[..., Any]], _PartialFunction]: + if _warn_parentheses_missing: + raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? 
Suggestion: `@batch()`.") + if batch_max_size < 1: + raise InvalidError("batch_max_size must be a positive integer.") + if batch_max_size >= MAX_BATCH_SIZE: + raise InvalidError(f"batch_max_size must be less than {MAX_BATCH_SIZE}.") + if batch_linger_ms < 0: + raise InvalidError("batch_linger_ms must be a non-negative integer.") + if batch_linger_ms >= MAX_BATCH_LINGER_MS: + raise InvalidError(f"batch_linger_ms must be less than {MAX_BATCH_LINGER_MS}.") + + def wrapper(raw_f: Union[Callable[[Any], Any], _PartialFunction]) -> _PartialFunction: + if isinstance(raw_f, _Function): + raw_f = raw_f.get_raw_f() + raise InvalidError( + f"Applying decorators for {raw_f} in the wrong order!\nUsage:\n\n" + "@app.function()\n@modal.batch()\ndef batched_function():\n ..." + ) + return _PartialFunction( + raw_f, + _PartialFunctionFlags.FUNCTION | _PartialFunctionFlags.BATCH, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, + ) + + return wrapper + + method = synchronize_api(_method) web_endpoint = synchronize_api(_web_endpoint) asgi_app = synchronize_api(_asgi_app) @@ -562,3 +614,4 @@ def wrapper(f: ExitHandlerType) -> _PartialFunction: build = synchronize_api(_build) enter = synchronize_api(_enter) exit = synchronize_api(_exit) +batch = synchronize_api(_batch) diff --git a/test/cls_test.py b/test/cls_test.py index 566b51b0d..28a35a6f3 100644 --- a/test/cls_test.py +++ b/test/cls_test.py @@ -867,3 +867,34 @@ def test_disabled_parameterized_snap_cls(): app.cls(enable_memory_snapshot=True)(ParameterizedClass2) app.cls(enable_memory_snapshot=True)(ParameterizedClass3) + + +app_batch = App() + + +def test_batch_method_duplicate_error(client): + with pytest.raises( + InvalidError, match="Modal class BatchClass_1 with a modal batch function cannot have other modal methods." 
+ ): + + @app_batch.cls(serialized=True) + class BatchClass_1: + @modal.method() + def method(self): + pass + + @modal.batch(batch_max_size=2, batch_linger_ms=0) + def batch_method(self): + pass + + with pytest.raises(InvalidError, match="Modal class BatchClass_2 can only have one batch function."): + + @app_batch.cls(serialized=True) + class BatchClass_2: + @modal.batch(batch_max_size=2, batch_linger_ms=0) + def batch_method_1(self): + pass + + @modal.batch(batch_max_size=2, batch_linger_ms=0) + def batch_method_2(self): + pass diff --git a/test/container_test.py b/test/container_test.py index a6b2702ec..b43539edb 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -64,6 +64,44 @@ def _get_inputs( return [api_pb2.FunctionGetInputsResponse(inputs=[x]) for x in inputs] +def _get_inputs_batch( + args_list: List[Tuple[Tuple, Dict]], + batch_max_size: int, + kill_switch=True, + method_name: Optional[str] = None, +): + input_pbs = [ + api_pb2.FunctionInput( + args=serialize(args), data_format=api_pb2.DATA_FORMAT_PICKLE, method_name=method_name or "" + ) + for args in args_list + ] + inputs = [ + *( + api_pb2.FunctionGetInputsItem(input_id=f"in-xyz{i}", function_call_id="fc-123", input=input_pb) + for i, input_pb in enumerate(input_pbs) + ), + *([api_pb2.FunctionGetInputsItem(kill_switch=True)] if kill_switch else []), + ] + response_list = [] + current_batch: List[Any] = [] + while inputs: + input = inputs.pop(0) + if input.kill_switch: + if len(current_batch) > 0: + response_list.append(api_pb2.FunctionGetInputsResponse(inputs=current_batch)) + current_batch = [input] + break + if len(current_batch) > batch_max_size: + response_list.append(api_pb2.FunctionGetInputsResponse(inputs=current_batch)) + current_batch = [] + current_batch.append(input) + + if len(current_batch) > 0: + response_list.append(api_pb2.FunctionGetInputsResponse(inputs=current_batch)) + return response_list + + def _get_multi_inputs(args: List[Tuple[str, Tuple, Dict]] = []) -> 
List[api_pb2.FunctionGetInputsResponse]: responses = [] for input_n, (method_name, input_args, input_kwargs) in enumerate(args): @@ -113,6 +151,8 @@ def _container_args( app_name: str = "", is_builder_function: bool = False, allow_concurrent_inputs: Optional[int] = None, + batch_max_size: Optional[int] = None, + batch_linger_ms: Optional[int] = None, serialized_params: Optional[bytes] = None, is_checkpointing_function: bool = False, deps: List[str] = ["im-1"], @@ -144,6 +184,8 @@ def _container_args( is_builder_function=is_builder_function, is_auto_snapshot=is_auto_snapshot, allow_concurrent_inputs=allow_concurrent_inputs, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, is_checkpointing_function=is_checkpointing_function, object_dependencies=[api_pb2.ObjectDependency(object_id=object_id) for object_id in deps], max_inputs=max_inputs, @@ -180,6 +222,8 @@ def _run_container( app_name: str = "", is_builder_function: bool = False, allow_concurrent_inputs: Optional[int] = None, + batch_max_size: int = 0, + batch_linger_ms: int = 0, serialized_params: Optional[bytes] = None, is_checkpointing_function: bool = False, deps: List[str] = ["im-1"], @@ -200,6 +244,8 @@ def _run_container( app_name, is_builder_function, allow_concurrent_inputs, + batch_max_size, + batch_linger_ms, serialized_params, is_checkpointing_function, deps, @@ -269,6 +315,16 @@ def _unwrap_scalar(ret: ContainerResult): return deserialize(ret.items[0].result.data, ret.client) +def _unwrap_batch_scalar(ret: ContainerResult, batch_size): + assert len(ret.items) == batch_size + outputs = [] + for item in ret.items: + assert item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS + outputs.append(deserialize(item.result.data, ret.client)) + assert len(outputs) == batch_size + return outputs + + def _unwrap_exception(ret: ContainerResult): assert len(ret.items) == 1 assert ret.items[0].result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE @@ -276,6 +332,17 @@ def 
_unwrap_exception(ret: ContainerResult): return ret.items[0].result.exception +def _unwrap_batch_exception(ret: ContainerResult, batch_size): + assert len(ret.items) == batch_size + outputs = [] + for item in ret.items: + assert item.result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE + assert "Traceback" in item.result.traceback + outputs.append(item.result.exception) + assert len(outputs) == batch_size + return outputs + + def _unwrap_generator(ret: ContainerResult) -> Tuple[List[Any], Optional[Exception]]: assert len(ret.items) == 1 item = ret.items[0] @@ -1021,6 +1088,97 @@ def test_concurrent_inputs_async_function(servicer): assert function_call_id and function_call_id == outputs[i - 1][2] +def _batch_function_test_helper(batch_func, servicer, args_list, expected_outputs, expected_status="success"): + batch_max_size = 4 + batch_linger_ms = 500 + inputs = _get_inputs_batch(args_list, batch_max_size) + + ret = _run_container( + servicer, + "test.supports.functions", + batch_func, + inputs=inputs, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, + ) + if expected_status == "success": + outputs = _unwrap_batch_scalar(ret, len(expected_outputs)) + else: + outputs = _unwrap_batch_exception(ret, len(expected_outputs)) + assert outputs == expected_outputs + + +@skip_github_non_linux +def test_batch_sync_function_full_batch(servicer): + inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [((10, 5), {}) for _ in range(4)] + expected_outputs = [2] * 4 + _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_outputs) + + +@skip_github_non_linux +def test_batch_sync_function_partial_batch(servicer): + inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [((10, 5), {}) for _ in range(2)] + expected_outputs = [2] * 2 + _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_outputs) + + +@skip_github_non_linux +def test_batch_sync_function_keyword_args(servicer): + inputs: List[Tuple[Tuple[Any, 
...], Dict[str, Any]]] = [((10,), {"y": 5}) for _ in range(4)] + expected_outputs = [2] * 4 + _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_outputs) + + +@skip_github_non_linux +def test_batch_sync_function_inputs_outputs_error(servicer): + # argument length does not match + inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [((10, 5), {}), ((10, 5, 1), {})] + with pytest.raises(InvalidError) as err: + _batch_function_test_helper("batch_function_sync", servicer, inputs, []) + assert "Modal batch function batch_function_sync takes 2 positional arguments, but one call has 3." in str(err) + + # Unexpected keyword arg + inputs = [((10, 5), {}), ((10,), {"z": 5})] + with pytest.raises(InvalidError) as err: + _batch_function_test_helper("batch_function_sync", servicer, inputs, []) + assert "Modal batch function batch_function_sync got an unexpected keyword argument z in one call." in str(err) + + # Multiple values with keyword arg + inputs = [((10, 5), {}), ((10,), {"x": 1})] + with pytest.raises(InvalidError) as err: + _batch_function_test_helper("batch_function_sync", servicer, inputs, []) + assert "Modal batch function batch_function_sync got multiple values for argument x in one call." in str(err) + + # output must be list + inputs = [((10, 5), {})] + with pytest.raises(InvalidError) as err: + _batch_function_test_helper("batch_function_outputs_not_list", servicer, inputs, []) + assert "Output of batch function batch_function_outputs_not_list must be a list." in str(err) + + # outputs must match length of inputs + inputs = [((10, 5), {})] + with pytest.raises(InvalidError) as err: + _batch_function_test_helper("batch_function_outputs_wrong_len", servicer, inputs, []) + assert ( + "Output of batch function batch_function_outputs_wrong_len must be a list of the same length as its inputs." 
+ in str(err) + ) + + +@skip_github_non_linux +def test_batch_sync_function_generic_error(servicer): + inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [((10, 0), {}) for _ in range(4)] + expected_outputs = ["ZeroDivisionError('division by zero')"] * 4 + _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_outputs, expected_status="failure") + + +@skip_github_non_linux +def test_batch_async_function(servicer): + inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [((10, 5), {}) for _ in range(4)] + expected_outputs = [2] * 4 + _batch_function_test_helper("batch_function_async", servicer, inputs, expected_outputs) + + @skip_github_non_linux def test_unassociated_function(servicer): ret = _run_container(servicer, "test.supports.functions", "unassociated_function") diff --git a/test/decorator_test.py b/test/decorator_test.py index a837a5202..0db1c7494 100644 --- a/test/decorator_test.py +++ b/test/decorator_test.py @@ -1,7 +1,7 @@ # Copyright Modal Labs 2023 import pytest -from modal import App, asgi_app, method, web_endpoint, wsgi_app +from modal import App, asgi_app, batch, method, web_endpoint, wsgi_app from modal.exception import InvalidError @@ -83,3 +83,16 @@ class Container: @web_endpoint() def generate(self): pass + + +def test_batch_method(): + app = App() + + with pytest.raises(InvalidError, match="remove the `@method`"): + + @app.cls() + class Container: + @method() # type: ignore + @batch(batch_max_size=2, batch_linger_ms=0) + def generate(self): + pass diff --git a/test/function_test.py b/test/function_test.py index 0fee2eaed..59db9d675 100644 --- a/test/function_test.py +++ b/test/function_test.py @@ -10,7 +10,7 @@ from synchronicity.exceptions import UserCodeException import modal -from modal import App, Image, Mount, NetworkFileSystem, Proxy, web_endpoint +from modal import App, Image, Mount, NetworkFileSystem, Proxy, batch, web_endpoint from modal._utils.async_utils import synchronize_api from modal._vendor import
cloudpickle from modal.exception import ExecutionError, InvalidError @@ -802,3 +802,26 @@ def test_function_decorator_on_method(): with pytest.raises(InvalidError, match="@app.cls"): app.function()(X.f) + + +def test_batch_function_invalid_error(): + app = App() + + with pytest.raises(InvalidError, match="must be a positive integer"): + app.function(batch(batch_max_size=0, batch_linger_ms=1))(dummy) + + with pytest.raises(InvalidError, match="must be a non-negative integer"): + app.function(batch(batch_max_size=1, batch_linger_ms=-1))(dummy) + + with pytest.raises(InvalidError, match="must be less than"): + app.function(batch(batch_max_size=1000, batch_linger_ms=1))(dummy) + + with pytest.raises(InvalidError, match="must be less than"): + app.function(batch(batch_max_size=1, batch_linger_ms=10 * 60 * 1000))(dummy) + + with pytest.raises(InvalidError, match="cannot return generators"): + + @app.function(serialized=True) + @batch(batch_max_size=1, batch_linger_ms=1) + def f(x): + yield [x_i**2 for x_i in x] diff --git a/test/supports/functions.py b/test/supports/functions.py index 4fcc9618d..dba51eab8 100644 --- a/test/supports/functions.py +++ b/test/supports/functions.py @@ -3,12 +3,13 @@ import asyncio import time -from typing import List +from typing import List, Tuple from modal import ( App, Sandbox, asgi_app, + batch, build, current_function_call_id, current_input_id, @@ -330,6 +331,37 @@ async def sleep_700_async(x): return x * x, current_input_id(), current_function_call_id() +@app.function() +@batch(batch_max_size=4, batch_linger_ms=500) +def batch_function_sync(x: Tuple[int], y: Tuple[int]): + outputs = [] + for x_i, y_i in zip(x, y): + outputs.append(x_i / y_i) + return outputs + + +@app.function() +@batch(batch_max_size=4, batch_linger_ms=500) +def batch_function_outputs_not_list(x: Tuple[int], y: Tuple[int]): + return str(x) + + +@app.function() +@batch(batch_max_size=4, batch_linger_ms=500) +def batch_function_outputs_wrong_len(x: Tuple[int], y: 
Tuple[int]): + return list(x) + [0] + + +@app.function() +@batch(batch_max_size=4, batch_linger_ms=500) +async def batch_function_async(x: Tuple[int], y: Tuple[int]): + outputs = [] + for x_i, y_i in zip(x, y): + outputs.append(x_i / y_i) + await asyncio.sleep(0.1) + return outputs + + def unassociated_function(x): return 100 - x