From dc664057b3137aa10243d65c0070b704eb2fd322 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Thu, 25 Jul 2024 15:45:08 +0000 Subject: [PATCH 1/8] initial --- modal/__init__.py | 3 +- modal/_container_entrypoint.py | 201 +++++++++++++++++++++++++++------ modal/_container_io_manager.py | 166 +++++++++++++++++++++------ modal/app.py | 8 +- modal/cls.py | 4 + modal/execution_context.py | 12 +- modal/functions.py | 9 ++ modal/partial_function.py | 57 +++++++++- test/container_test.py | 157 +++++++++++++++++++++++++ test/supports/functions.py | 20 ++++ 10 files changed, 558 insertions(+), 79 deletions(-) diff --git a/modal/__init__.py b/modal/__init__.py index 0b9a3539a..1c5021839 100644 --- a/modal/__init__.py +++ b/modal/__init__.py @@ -22,7 +22,7 @@ from .image import Image from .mount import Mount from .network_file_system import NetworkFileSystem - from .partial_function import asgi_app, build, enter, exit, method, web_endpoint, web_server, wsgi_app + from .partial_function import asgi_app, batch, build, enter, exit, method, web_endpoint, web_server, wsgi_app from .proxy import Proxy from .queue import Queue from .retries import Retries @@ -64,6 +64,7 @@ "Tunnel", "Volume", "asgi_app", + "batch", "build", "current_function_call_id", "current_input_id", diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 7bf844bcd..5525aa0b3 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -13,7 +13,7 @@ import typing from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from google.protobuf.message import Message from synchronicity import Interface @@ -107,6 +107,7 @@ class FinalizedFunction: is_async: bool is_generator: bool data_format: int # api_pb2.DataFormat + signature_info: "FuncSignatureInfo" class Service(metaclass=ABCMeta): @@ -128,6 +129,27 @@ def get_finalized_functions( ... +@dataclass +class FuncSignatureInfo: + func_name: str + params_and_defaults: List[Tuple[str, Any]] + + +def get_func_signature_info(func, ignore_self=False): + signature_params = list(inspect.signature(func).parameters.values()) + function_name = func.__name__ + + if ignore_self: + if len(signature_params) == 0: + raise ValueError( + "Methods must take a 'self' argument, but the " f"method '{func.__name__}' does not have one." + ) + signature_params = signature_params[1:] + + param_names_and_defaults = [(param.name, param.default) for param in signature_params] + return FuncSignatureInfo(function_name, param_names_and_defaults) + + @dataclass class ImportedFunction(Service): user_cls_instance: Any @@ -153,6 +175,7 @@ def get_finalized_functions( is_async=is_async, is_generator=is_generator, data_format=api_pb2.DATA_FORMAT_PICKLE, + signature_info=get_func_signature_info(self._user_defined_callable), ) } @@ -166,6 +189,7 @@ def get_finalized_functions( is_async=True, is_generator=True, data_format=api_pb2.DATA_FORMAT_ASGI, + signature_info=get_func_signature_info(self._user_defined_callable), ) } @@ -200,6 +224,7 @@ def get_finalized_functions( is_async=is_async, is_generator=is_generator, data_format=api_pb2.DATA_FORMAT_PICKLE, + signature_info=get_func_signature_info(bound_func, ignore_self=True), ) else: web_callable = construct_webhook_callable(bound_func, webhook_config, container_io_manager) @@ -208,6 +233,7 @@ def get_finalized_functions( is_async=True, is_generator=True, data_format=api_pb2.DATA_FORMAT_ASGI, + signature_info=get_func_signature_info(bound_func, ignore_self=True), ) finalized_functions[method_name] = finalized_function return finalized_functions @@ -323,30 +349,112 @@ def _sigint_handler(): self.loop.remove_signal_handler(signal.SIGINT) +def _aggregate_ids_and_args( + local_inputs: Union[LocalInput, List[LocalInput]], + container_io_manager: "modal._container_io_manager.ContainerIOManager", + signature_info: FuncSignatureInfo, +) -> Tuple[Union[str, List[str]], Union[str, List[str]], List[Any], Dict[str, Any]]: + if isinstance(local_inputs, list): + input_ids = [local_input.input_id for local_input in local_inputs] + function_call_ids = [local_input.function_call_id for local_input in local_inputs] + num_required_args = len( + [param for param in signature_info.params_and_defaults if param[1] == inspect.Parameter.empty] + ) + num_args = len(signature_info.params_and_defaults) + args_list, kwargs_list = zip( + *[ + container_io_manager.deserialize(local_input.input_args) if local_input.input_args else ((), {}) + for local_input in local_inputs + ] + ) + args_and_kwargs_dict = [{} for _ in input_ids] + for i, (args, kwargs) in enumerate(zip(args_list, kwargs_list)): + for j, arg in enumerate(args): + if j < len(signature_info.params_and_defaults): + args_and_kwargs_dict[i][signature_info.params_and_defaults[j][0]] = arg + else: + raise ( + InvalidError( + f"Batched function {signature_info.func_name} takes {num_required_args} \ + positional arguments but {len(args)} were given" + ) + if num_args == num_required_args + else InvalidError( + f"Batched function {signature_info.func_name} takes from {num_required_args} to {num_args} \ + positional arguments but {len(args)} were given" + ) + ) + for k, v in kwargs.items(): + if k in [param[0] for param in signature_info.params_and_defaults]: + if k in args_and_kwargs_dict[i]: + raise InvalidError( + f"Batched function {signature_info.func_name} got multiple values for argument {k}" + ) + args_and_kwargs_dict[i][k] = v + else: + raise InvalidError( + f"Batched function {signature_info.func_name} got an unexpected keyword argument: {k}" + ) + + formatted_args = [] + for arg_name, _ in signature_info.params_and_defaults[:num_required_args]: + if any(arg_name not in args_and_kwargs_dict[i] for i in range(len(input_ids))): + raise InvalidError( + f"Batched function {signature_info.func_name} missing required positional argument {arg_name}" + ) + formatted_args.append([args_and_kwargs_dict[i][arg_name] for i in range(len(input_ids))]) + + formatted_kwargs = { + kwarg_name: [args_and_kwargs_dict[j].get(kwarg_name, default) for j in range(len(input_ids))] + for kwarg_name, default in signature_info.params_and_defaults[num_required_args:] + } + + return input_ids, function_call_ids, formatted_args, formatted_kwargs + + else: + input_id = local_inputs.input_id + function_call_id = local_inputs.function_call_id + args, kwargs = ( + container_io_manager.deserialize(local_inputs.input_args) if local_inputs.input_args else ((), {}) + ) + return input_id, function_call_id, args, kwargs + + def call_function( user_code_event_loop: UserCodeEventLoop, container_io_manager: "modal._container_io_manager.ContainerIOManager", finalized_functions: Dict[str, FinalizedFunction], input_concurrency: int, + batch_max_size: Optional[int], + batch_linger_ms: Optional[int], ): - async def run_input_async(finalized_function: FinalizedFunction, local_input: LocalInput) -> None: + async def run_input_async( + finalized_function: FinalizedFunction, + local_inputs: Union[LocalInput, List[LocalInput]], + container_io_manager: "modal._container_io_manager.ContainerIOManager", + ) -> None: started_at = time.time() - reset_context = _set_current_context_ids(local_input.input_id, local_input.function_call_id) - async with container_io_manager.handle_input_exception.aio(local_input.input_id, started_at): - logger.debug(f"Starting input {local_input.input_id} (async)") - res = finalized_function.callable(*local_input.args, **local_input.kwargs) - logger.debug(f"Finished input {local_input.input_id} (async)") + input_ids, function_call_ids, args, kwargs = _aggregate_ids_and_args( + local_inputs, container_io_manager, finalized_function.signature_info + ) + reset_context = _set_current_context_ids(input_ids, function_call_ids) + async with container_io_manager.handle_input_exception.aio(input_ids, started_at): + logger.debug(f"Starting input {input_ids} (async)") + res = finalized_function.callable(*args, **kwargs) + logger.debug(f"Finished input {input_ids} (async)") # TODO(erikbern): any exception below shouldn't be considered a user exception if finalized_function.is_generator: if not inspect.isasyncgen(res): raise InvalidError(f"Async generator function returned value of type {type(res)}") + if isinstance(input_ids, list) or isinstance(function_call_ids, list): + raise InvalidError("Batched functions cannot return generators.") # Send up to this many outputs at a time. generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024) generator_output_task = asyncio.create_task( container_io_manager.generator_output_task.aio( - local_input.function_call_id, + function_call_ids, finalized_function.data_format, generator_queue, ) @@ -361,7 +469,7 @@ async def run_input_async(finalized_function: FinalizedFunction, local_input: Lo await generator_output_task # Wait to finish sending generator outputs. message = api_pb2.GeneratorDone(items_total=item_count) await container_io_manager.push_output.aio( - local_input.input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE + input_ids, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE ) else: if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res): @@ -370,18 +478,23 @@ async def run_input_async(finalized_function: FinalizedFunction, local_input: Lo " You might need to use @app.function(..., is_generator=True)." ) value = await res - await container_io_manager.push_output.aio( - local_input.input_id, started_at, value, finalized_function.data_format - ) + await container_io_manager.push_output.aio(input_ids, started_at, value, finalized_function.data_format) reset_context() - def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInput) -> None: + def run_input_sync( + finalized_function: FinalizedFunction, + local_inputs: Union[LocalInput, List[LocalInput]], + container_io_manager: "modal._container_io_manager.ContainerIOManager", + ) -> None: started_at = time.time() - reset_context = _set_current_context_ids(local_input.input_id, local_input.function_call_id) - with container_io_manager.handle_input_exception(local_input.input_id, started_at): - logger.debug(f"Starting input {local_input.input_id} (sync)") - res = finalized_function.callable(*local_input.args, **local_input.kwargs) - logger.debug(f"Finished input {local_input.input_id} (sync)") + input_ids, function_call_ids, args, kwargs = _aggregate_ids_and_args( + local_inputs, container_io_manager, finalized_function.signature_info + ) + reset_context = _set_current_context_ids(input_ids, function_call_ids) + with container_io_manager.handle_input_exception(input_ids, started_at): + logger.debug(f"Starting input {input_ids} (sync)") + res = finalized_function.callable(*args, **kwargs) + logger.debug(f"Finished input {input_ids} (sync)") # TODO(erikbern): any exception below shouldn't be considered a user exception if finalized_function.is_generator: @@ -391,7 +504,7 @@ def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInpu # Send up to this many outputs at a time. generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024) generator_output_task: concurrent.futures.Future = container_io_manager.generator_output_task( # type: ignore - local_input.function_call_id, + function_call_ids, finalized_function.data_format, generator_queue, _future=True, # type: ignore # Synchronicity magic to return a future. @@ -405,18 +518,24 @@ def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInpu container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL) generator_output_task.result() # Wait to finish sending generator outputs. message = api_pb2.GeneratorDone(items_total=item_count) - container_io_manager.push_output( - local_input.input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE - ) + container_io_manager.push_output(input_ids, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE) else: if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res): raise InvalidError( f"Sync (non-generator) function return value of type {type(res)}." " You might need to use @app.function(..., is_generator=True)." ) - container_io_manager.push_output(local_input.input_id, started_at, res, finalized_function.data_format) + container_io_manager.push_output(input_ids, started_at, res, finalized_function.data_format) reset_context() + def _get_finalized_functions(local_inputs: Union[LocalInput, List[LocalInput]]) -> FinalizedFunction: + if isinstance(local_inputs, list): + assert len(local_inputs) > 0 + assert all(local_input.method_name == local_inputs[0].method_name for local_input in local_inputs) + return finalized_functions[local_inputs[0].method_name] + else: + return finalized_functions[local_inputs.method_name] + if input_concurrency > 1: with DaemonizedThreadPool(max_threads=input_concurrency) as thread_pool: @@ -425,24 +544,28 @@ async def run_concurrent_inputs(): # but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s # for them to resolve gracefully: async with TaskContext(0.01) as task_context: - async for local_input in container_io_manager.run_inputs_outputs.aio(input_concurrency): - finalized_function = finalized_functions[local_input.method_name] + async for local_inputs in container_io_manager.run_inputs_outputs.aio( + input_concurrency, batch_max_size, batch_linger_ms + ): + finalized_function = _get_finalized_functions(local_inputs) # Note that run_inputs_outputs will not return until the concurrency semaphore has # released all its slots so that they can be acquired by the run_inputs_outputs finalizer # This prevents leaving the task_context before outputs have been created # TODO: refactor to make this a bit more easy to follow? if finalized_function.is_async: - task_context.create_task(run_input_async(finalized_function, local_input)) + task_context.create_task( + run_input_async(finalized_function, local_inputs, container_io_manager) + ) else: # run sync input in thread - thread_pool.submit(run_input_sync, finalized_function, local_input) + thread_pool.submit(run_input_sync, finalized_function, local_inputs, container_io_manager) user_code_event_loop.run(run_concurrent_inputs()) else: - for local_input in container_io_manager.run_inputs_outputs(input_concurrency): - finalized_function = finalized_functions[local_input.method_name] + for local_inputs in container_io_manager.run_inputs_outputs(input_concurrency, batch_max_size, batch_linger_ms): + finalized_function = _get_finalized_functions(local_inputs) if finalized_function.is_async: - user_code_event_loop.run(run_input_async(finalized_function, local_input)) + user_code_event_loop.run(run_input_async(finalized_function, local_inputs, container_io_manager)) else: # Set up a custom signal handler for `SIGUSR1`, which gets translated to an InputCancellation # during function execution. This is sent to cancel inputs from the user @@ -453,7 +576,7 @@ def _cancel_input_signal_handler(signum, stackframe): # run this sync code in the main thread, blocking the "userland" event loop # this lets us cancel it using a signal handler that raises an exception try: - run_input_sync(finalized_function, local_input) + run_input_sync(finalized_function, local_inputs, container_io_manager) finally: signal.signal(signal.SIGUSR1, usr1_handler) # reset signal handler @@ -738,10 +861,14 @@ def main(container_args: api_pb2.ContainerArguments, client: Client): # Container can fetch multiple inputs simultaneously if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL: - # Concurrency doesn't apply for `modal shell`. + # Concurrency and batching doesn't apply for `modal shell`. input_concurrency = 1 + batch_max_size = 0 + batch_linger_ms = 0 else: input_concurrency = function_def.allow_concurrent_inputs or 1 + batch_max_size = function_def.batch_max_size or 0 + batch_linger_ms = function_def.batch_linger_ms or 0 # Get ids and metadata for objects (primarily functions and classes) on the app container_app: RunningApp = container_io_manager.get_app_objects() @@ -779,7 +906,6 @@ def main(container_args: api_pb2.ContainerArguments, client: Client): if function_def.is_checkpointing_function: container_io_manager.memory_snapshot() - # Install hooks for interactive functions. if function_def.pty_info.pty_type != api_pb2.PTYInfo.PTY_TYPE_UNSPECIFIED: @@ -804,7 +930,14 @@ def breakpoint_wrapper(): # Execute the function. try: - call_function(event_loop, container_io_manager, finalized_functions, input_concurrency) + call_function( + event_loop, + container_io_manager, + finalized_functions, + input_concurrency, + batch_max_size, + batch_linger_ms, + ) finally: # Run exit handlers. From this point onward, ignore all SIGINT signals that come from # graceful shutdowns originating on the worker, as well as stray SIGUSR1 signals that diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index ded25a955..3d52c1f59 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -9,7 +9,7 @@ import traceback from dataclasses import dataclass from pathlib import Path -from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, List, Optional, Set, Tuple +from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, List, Optional, Set, Tuple, Union from google.protobuf.empty_pb2 import Empty from google.protobuf.message import Message @@ -47,8 +47,7 @@ class LocalInput: input_id: str function_call_id: str method_name: str - args: Any - kwargs: Any + input_args: Any class _ContainerIOManager: @@ -350,7 +349,9 @@ def get_max_inputs_to_fetch(self): return math.ceil(RTT_S / max(self.get_average_call_time(), 1e-6)) @synchronizer.no_io_translation - async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.FunctionInput]]: + async def _generate_inputs( + self, + ) -> AsyncIterator[Union[Tuple[str, str, api_pb2.FunctionInput], List[Tuple[str, str, api_pb2.FunctionInput]]]]: request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id) eof_received = False iteration = 0 @@ -358,6 +359,8 @@ async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.Functi request.average_call_time = self.get_average_call_time() request.max_values = self.get_max_inputs_to_fetch() # Deprecated; remove. request.input_concurrency = self._input_concurrency + request.batch_max_size = self._batch_max_size + request.batch_linger_ms = self._batch_linger_ms await self._semaphore.acquire() yielded = False @@ -375,11 +378,11 @@ async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.Functi ) await asyncio.sleep(response.rate_limit_sleep_duration) elif response.inputs: - # for input cancellations and concurrency logic we currently assume - # that there is no input buffering in the container - assert len(response.inputs) == 1 - - for item in response.inputs: + if self._batch_max_size == 0: + # for input cancellations and concurrency logic we currently assume + # that there is no input buffering in the container + assert len(response.inputs) == 1 + item = response.inputs[0] if item.kill_switch: logger.debug(f"Task {self.task_id} input kill signal input.") eof_received = True @@ -401,48 +404,133 @@ async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.Functi if item.input.final_input or self.function_def.max_inputs == 1: eof_received = True break + + else: + assert len(response.inputs) <= request.batch_max_size + + inputs_list = [] + for item in response.inputs: + if item.kill_switch: + assert len(response.inputs) == 1 + logger.debug(f"Task {self.task_id} input kill signal input.") + eof_received = True + break + assert item.input_id not in self.cancelled_input_ids + + # If we got a pointer to a blob, download it from S3. + if item.input.WhichOneof("args_oneof") == "args_blob_id": + input_pb = await self.populate_input_blobs(item.input) + else: + input_pb = item.input + + inputs_list.append((item.input_id, item.function_call_id, input_pb)) + if item.input.final_input: + eof_received = True + logger.error("Final input not expected in batched input stream") + break + + if not eof_received: + yield inputs_list + yielded = True + + # We only support max_inputs = 1 at the moment + if self.function_def.max_inputs == 1: + eof_received = True + finally: if not yielded: self._semaphore.release() @synchronizer.no_io_translation - async def run_inputs_outputs(self, input_concurrency: int = 1) -> AsyncIterator[LocalInput]: + async def run_inputs_outputs( + self, + input_concurrency: int = 1, + batch_max_size: int = 0, + batch_linger_ms: int = 0, + ) -> AsyncIterator[Union[LocalInput, List[LocalInput]]]: # Ensure we do not fetch new inputs when container is too busy. # Before trying to fetch an input, acquire the semaphore: # - if no input is fetched, release the semaphore. # - or, when the output for the fetched input is sent, release the semaphore. self._input_concurrency = input_concurrency + self._batch_max_size = batch_max_size + self._batch_linger_ms = batch_linger_ms self._semaphore = asyncio.Semaphore(input_concurrency) - async for input_id, function_call_id, input_pb in self._generate_inputs(): - args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) - self.current_input_id, self.current_input_started_at = (input_id, time.time()) - yield LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs) - self.current_input_id, self.current_input_started_at = (None, None) + if batch_max_size == 0: + async for input_id, function_call_id, input_pb in self._generate_inputs(): + self.current_input_id, self.current_input_started_at = (input_id, time.time()) + yield LocalInput(input_id, function_call_id, input_pb.method_name, input_pb.args) + self.current_input_id, self.current_input_started_at = (None, None) + else: + async for inputs_list in self._generate_inputs(): + local_inputs_list = [] + for input_id, function_call_id, input_pb in inputs_list: + self.current_input_id, self.current_input_started_at = (input_id, time.time()) + local_inputs_list.append( + LocalInput(input_id, function_call_id, input_pb.method_name, input_pb.args) + ) + self.current_input_id, self.current_input_started_at = (None, None) + yield local_inputs_list # collect all active input slots, meaning all inputs have wrapped up. for _ in range(input_concurrency): await self._semaphore.acquire() - async def _push_output(self, input_id, started_at: float, data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, **kwargs): - # upload data to S3 if too big. - if "data" in kwargs and kwargs["data"] and len(kwargs["data"]) > MAX_OBJECT_SIZE_BYTES: - data_blob_id = await blob_upload(kwargs["data"], self._client.stub) - # mutating kwargs. - del kwargs["data"] - kwargs["data_blob_id"] = data_blob_id - - output = api_pb2.FunctionPutOutputsItem( - input_id=input_id, - input_started_at=started_at, - output_created_at=time.time(), - result=api_pb2.GenericResult(**kwargs), - data_format=data_format, - ) - + async def _push_output( + self, input_ids: Union[str, List[str]], started_at: float, data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, **kwargs + ): + outputs = [] + if isinstance(input_ids, list): + data_list = [] + if "data" in kwargs and kwargs["data"]: + # split the list of data in kwargs to respective input_ids + data_list = self.deserialize_data_format(kwargs.pop("data"), data_format) + assert isinstance(data_list, list), "Output of batched function must be a list" + assert len(data_list) == len( + input_ids + ), "Output of batched function must be a list of the same length as its list of inputs." + for i, input_id in enumerate(input_ids): + data = self.serialize_data_format(data_list[i], data_format) if data_list else None + result = ( + ( + # upload data to S3 if too big. + api_pb2.GenericResult(data_blob_id=await blob_upload(data, self._client.stub), **kwargs) + if len(data) > MAX_OBJECT_SIZE_BYTES + else api_pb2.GenericResult(data=data, **kwargs) + ) + if data + else api_pb2.GenericResult(**kwargs) + ) + outputs.append( + api_pb2.FunctionPutOutputsItem( + input_id=input_id, + input_started_at=started_at, + output_created_at=time.time(), + result=result, + data_format=data_format, + ) + ) + else: + # upload data to S3 if too big. + if "data" in kwargs and kwargs["data"] and len(kwargs["data"]) > MAX_OBJECT_SIZE_BYTES: + data_blob_id = await blob_upload(kwargs["data"], self._client.stub) + # mutating kwargs. + del kwargs["data"] + kwargs["data_blob_id"] = data_blob_id + + outputs.append( + api_pb2.FunctionPutOutputsItem( + input_id=input_ids, + input_started_at=started_at, + output_created_at=time.time(), + result=api_pb2.GenericResult(**kwargs), + data_format=data_format, + ) + ) await retry_transient_errors( self._client.stub.FunctionPutOutputs, - api_pb2.FunctionPutOutputsRequest(outputs=[output]), + api_pb2.FunctionPutOutputsRequest(outputs=outputs), additional_status_codes=[Status.RESOURCE_EXHAUSTED], max_retries=None, # Retry indefinitely, trying every 1s. ) @@ -502,7 +590,9 @@ async def handle_user_exception(self) -> AsyncGenerator[None, None]: raise UserException() @asynccontextmanager - async def handle_input_exception(self, input_id, started_at: float) -> AsyncGenerator[None, None]: + async def handle_input_exception( + self, input_ids: Union[str, List[str]], started_at: float + ) -> AsyncGenerator[None, None]: """Handle an exception while processing a function input.""" try: yield @@ -517,7 +607,7 @@ async def handle_input_exception(self, input_id, started_at: float) -> AsyncGene # just skip creating any output for this input and keep going with the next instead # it should have been marked as cancelled already in the backend at this point so it # won't be retried - logger.warning(f"The current input ({input_id=}) was cancelled by a user request") + logger.warning(f"The current input ({input_ids=}) was cancelled by a user request") await self.complete_call(started_at) return except BaseException as exc: @@ -541,11 +631,15 @@ async def handle_input_exception(self, input_id, started_at: float) -> AsyncGene repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception" await self._push_output( - input_id, + input_ids, started_at=started_at, data_format=api_pb2.DATA_FORMAT_PICKLE, status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, - data=self.serialize_exception(exc), + data=( + self.serialize_exception([exc] for _ in input_ids) + if isinstance(input_ids, list) + else self.serialize_exception(exc) + ), exception=repr_exc, traceback=traceback.format_exc(), serialized_tb=serialized_tb, diff --git a/modal/app.py b/modal/app.py index 47476fc19..d57ae295f 100644 --- a/modal/app.py +++ b/modal/app.py @@ -584,13 +584,15 @@ def wrapped( ) if isinstance(f, _PartialFunction): - # typically for @function-wrapped @web_endpoint and @asgi_app + # typically for @function-wrapped @web_endpoint, @asgi_app, or @batch f.wrapped = True info = FunctionInfo(f.raw_f, serialized=serialized, name_override=name) raw_f = f.raw_f webhook_config = f.webhook_config is_generator = f.is_generator keep_warm = f.keep_warm or keep_warm + batch_max_size = f.batch_max_size + batch_linger_ms = f.batch_linger_ms if webhook_config and interactive: raise InvalidError("interactive=True is not supported with web endpoint functions") @@ -626,6 +628,8 @@ def f(self, x): info = FunctionInfo(f, serialized=serialized, name_override=name) webhook_config = None + batch_max_size = None + batch_linger_ms = None raw_f = f if info.function_name.endswith(".app"): @@ -663,6 +667,8 @@ def f(self, x): retries=retries, concurrency_limit=concurrency_limit, allow_concurrent_inputs=allow_concurrent_inputs, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, container_idle_timeout=container_idle_timeout, timeout=timeout, keep_warm=keep_warm, diff --git a/modal/cls.py b/modal/cls.py index 52e025e17..13c05cb6f 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -359,6 +359,8 @@ def with_options( timeout: Optional[int] = None, concurrency_limit: Optional[int] = None, allow_concurrent_inputs: Optional[int] = None, + batch_max_size: Optional[int] = None, + batch_linger_ms: Optional[int] = None, container_idle_timeout: Optional[int] = None, allow_background_volume_commits: Optional[bool] = None, ) -> "_Cls": @@ -404,6 +406,8 @@ def with_options( replace_volume_mounts=replace_volume_mounts, volume_mounts=volume_mounts, allow_concurrent_inputs=allow_concurrent_inputs, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, ) return cls diff --git a/modal/execution_context.py b/modal/execution_context.py index 764d85cf8..e45cab899 100644 --- a/modal/execution_context.py +++ b/modal/execution_context.py @@ -1,6 +1,6 @@ # Copyright Modal Labs 2024 from contextvars import ContextVar -from typing import Callable, Optional +from typing import Callable, List, Optional, Union from modal._container_io_manager import _ContainerIOManager from modal._utils.async_utils import synchronize_api @@ -32,7 +32,7 @@ async def _interact() -> None: interact = synchronize_api(_interact) -def current_input_id() -> Optional[str]: +def current_input_id() -> Optional[Union[str, List[str]]]: """Returns the input ID for the current input. Can only be called from Modal function (i.e. in a container context). @@ -70,9 +70,11 @@ def process_stuff(): return None -def _set_current_context_ids(input_id: str, function_call_id: str) -> Callable[[], None]: - input_token = _current_input_id.set(input_id) - function_call_token = _current_function_call_id.set(function_call_id) +def _set_current_context_ids( + input_ids: Union[str, List[str]], function_call_ids: Union[str, List[str]] +) -> Callable[[], None]: + input_token = _current_input_id.set(input_ids) + function_call_token = _current_function_call_id.set(function_call_ids) def _reset_current_context_ids(): _current_input_id.reset(input_token) diff --git a/modal/functions.py b/modal/functions.py index 6fbf8dfb7..7d259f119 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -344,6 +344,8 @@ async def _load(method_bound_function: "_Function", resolver: Resolver, existing is_method=True, use_function_id=class_service_function.object_id, use_method_name=method_name, + batch_max_size=partial_function.batch_max_size, + batch_linger_ms=partial_function.batch_linger_ms, ) assert resolver.app_id request = api_pb2.FunctionCreateRequest( @@ -484,6 +486,8 @@ def from_args( timeout: Optional[int] = None, concurrency_limit: Optional[int] = None, allow_concurrent_inputs: Optional[int] = None, + batch_max_size: Optional[int] = None, + batch_linger_ms: Optional[float] = None, container_idle_timeout: Optional[int] = None, cpu: Optional[float] = None, keep_warm: Optional[int] = None, # keep_warm=True is equivalent to keep_warm=1 @@ -515,6 +519,7 @@ def from_args( assert info.user_cls assert not webhook_config assert not schedule + assert batch_max_size is None if secret is not None: deprecation_error( @@ -652,6 +657,8 @@ def from_args( ) else: raise InvalidError("Webhooks cannot be generators") + if is_generator and batch_max_size: + raise InvalidError("Batched functions cannot be generators") if container_idle_timeout is not None and container_idle_timeout <= 0: raise InvalidError("`container_idle_timeout` must be > 0") @@ -808,6 +815,8 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona app_name=app_name, is_builder_function=is_builder_function, allow_concurrent_inputs=allow_concurrent_inputs or 0, + batch_max_size=batch_max_size or 0, + batch_linger_ms=batch_linger_ms or 0, worker_id=config.get("worker_id"), is_auto_snapshot=is_auto_snapshot, is_method=bool(info.user_cls) and not info.is_service_class(), diff --git a/modal/partial_function.py b/modal/partial_function.py index ff5d107b2..f2a60785a 100644 --- a/modal/partial_function.py +++ b/modal/partial_function.py @@ -20,6 +20,9 @@ from .exception import InvalidError, deprecation_error, deprecation_warning from .functions import _Function +MAX_BATCH_SIZE = 49 +MAX_BATCH_LINGER_MS = 12000 + class _PartialFunctionFlags(enum.IntFlag): FUNCTION: int = 1 @@ -27,6 +30,7 @@ class _PartialFunctionFlags(enum.IntFlag): ENTER_PRE_SNAPSHOT: int = 4 ENTER_POST_SNAPSHOT: int = 8 EXIT: int = 16 + BATCH: int = 32 @staticmethod def all() -> "_PartialFunctionFlags": @@ -34,13 +38,15 @@ def all() -> "_PartialFunctionFlags": class _PartialFunction: - """Intermediate function, produced by @method or @web_endpoint""" + """Intermediate function, produced by @method, @web_endpoint, or @batch""" raw_f: Callable[..., Any] flags: _PartialFunctionFlags webhook_config: Optional[api_pb2.WebhookConfig] is_generator: Optional[bool] keep_warm: Optional[int] + batch_max_size: Optional[int] + batch_linger_ms: Optional[int] def __init__( self, @@ -49,6 +55,8 @@ def __init__( webhook_config: Optional[api_pb2.WebhookConfig] = None, is_generator: Optional[bool] = None, keep_warm: Optional[int] = None, + batch_max_size: Optional[int] = None, + batch_linger_ms: Optional[int] = None, ): self.raw_f = raw_f self.flags = flags @@ -56,6 +64,8 @@ def __init__( self.is_generator = is_generator self.keep_warm = keep_warm self.wrapped = False # Make sure that this was converted into a FunctionHandle + self.batch_max_size = batch_max_size + self.batch_linger_ms = batch_linger_ms def __get__(self, obj, objtype=None) -> _Function: k = self.raw_f.__name__ @@ -93,6 +103,19 @@ def add_flags(self, flags) -> "_PartialFunction": flags=(self.flags | flags), webhook_config=self.webhook_config, keep_warm=self.keep_warm, + batch_max_size=self.batch_max_size, + batch_linger_ms=self.batch_linger_ms, + ) + + def set_batch_params(self, batch_max_size, batch_linger_ms) -> "_PartialFunction": + self.wrapped = True + return _PartialFunction( + raw_f=self.raw_f, + flags=(self.flags | _PartialFunctionFlags.BATCH), + webhook_config=self.webhook_config, + keep_warm=self.keep_warm, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, ) @@ -190,7 +213,7 @@ def f(self): def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction: nonlocal is_generator - if isinstance(raw_f, _PartialFunction) and raw_f.webhook_config: + if isinstance(raw_f, _PartialFunction) and (raw_f.webhook_config or (raw_f.batch_max_size is not None)): raw_f.wrapped = True # suppress later warning raise InvalidError( "Web endpoints on classes should not be wrapped by `@method`. " @@ -554,6 +577,35 @@ def wrapper(f: ExitHandlerType) -> _PartialFunction: return wrapper +def _batch( + _warn_parentheses_missing=None, + *, + batch_max_size: int, + batch_linger_ms: int, +) -> Callable[[Callable[..., Any]], _PartialFunction]: + if _warn_parentheses_missing: + raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@batch()`.") + if batch_max_size < 1: + raise InvalidError("batch_max_size must be a positive integer.") + if batch_max_size > MAX_BATCH_SIZE: + raise InvalidError(f"batch_max_size must be less than or equal to {MAX_BATCH_SIZE}.") + if batch_linger_ms < 0: + raise InvalidError("batch_linger_ms must be a non-negative integer.") + if batch_linger_ms > MAX_BATCH_LINGER_MS: + raise InvalidError(f"batch_linger_ms must be less than or equal to {MAX_BATCH_LINGER_MS}.") + + def wrapper(f: Union[Callable[[Any], Any], _PartialFunction]) -> _PartialFunction: + if isinstance(f, _PartialFunction): + _disallow_wrapping_method(f, "batch") + return f.set_batch_params(batch_max_size, batch_linger_ms) + else: + return _PartialFunction( + f, _PartialFunctionFlags.BATCH, batch_max_size=batch_max_size, batch_linger_ms=batch_linger_ms + ) + + return wrapper + + method = synchronize_api(_method) web_endpoint = synchronize_api(_web_endpoint) asgi_app = synchronize_api(_asgi_app) @@ -562,3 +614,4 @@ def wrapper(f: ExitHandlerType) -> _PartialFunction: build = synchronize_api(_build) enter = synchronize_api(_enter) exit = synchronize_api(_exit) +batch = synchronize_api(_batch) diff --git a/test/container_test.py b/test/container_test.py index a6b2702ec..03661a3d0 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -64,6 +64,44 @@ def _get_inputs( return [api_pb2.FunctionGetInputsResponse(inputs=[x]) for x in inputs] +def _get_inputs_batched( + args_list: List[Tuple[Tuple, Dict]], + batch_max_size: int, + kill_switch=True, + method_name: Optional[str] = None, +): + input_pbs = [ + api_pb2.FunctionInput( + args=serialize(args), data_format=api_pb2.DATA_FORMAT_PICKLE, method_name=method_name or "" + ) + for args in args_list + ] + inputs = [ + *( + api_pb2.FunctionGetInputsItem(input_id=f"in-xyz{i}", function_call_id="fc-123", input=input_pb) + for i, input_pb in enumerate(input_pbs) + ), + *([api_pb2.FunctionGetInputsItem(kill_switch=True)] if kill_switch else []), + ] + response_list = [] + current_batch = [] + while inputs: + input = inputs.pop(0) + if input.kill_switch: + if len(current_batch) > 0: + response_list.append(api_pb2.FunctionGetInputsResponse(inputs=current_batch)) + current_batch = [input] + break + if len(current_batch) > batch_max_size: + response_list.append(api_pb2.FunctionGetInputsResponse(inputs=current_batch)) + current_batch = [] + current_batch.append(input) + + if len(current_batch) > 0: + response_list.append(api_pb2.FunctionGetInputsResponse(inputs=current_batch)) + return response_list + + def _get_multi_inputs(args: List[Tuple[str, Tuple, Dict]] = []) -> List[api_pb2.FunctionGetInputsResponse]: responses = [] for input_n, (method_name, input_args, input_kwargs) in enumerate(args): @@ -113,6 +151,8 @@ def _container_args( app_name: str = "", is_builder_function: bool = False, allow_concurrent_inputs: Optional[int] = None, + batch_max_size: Optional[int] = None, + batch_linger_ms: Optional[int] = None, serialized_params: Optional[bytes] = None, is_checkpointing_function: bool = False, deps: List[str] = ["im-1"], @@ -144,6 +184,8 @@ def _container_args( is_builder_function=is_builder_function, is_auto_snapshot=is_auto_snapshot, allow_concurrent_inputs=allow_concurrent_inputs, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, is_checkpointing_function=is_checkpointing_function, object_dependencies=[api_pb2.ObjectDependency(object_id=object_id) for object_id in deps], max_inputs=max_inputs, @@ -180,6 +222,8 @@ def _run_container( app_name: str = "", is_builder_function: bool = False, allow_concurrent_inputs: Optional[int] = None, + batch_max_size: int = 0, + batch_linger_ms: int = 0, serialized_params: Optional[bytes] = None, is_checkpointing_function: bool = False, deps: List[str] = ["im-1"], @@ -200,6 +244,8 @@ def _run_container( app_name, is_builder_function, allow_concurrent_inputs, + batch_max_size, + batch_linger_ms, serialized_params, is_checkpointing_function, deps, @@ -269,6 +315,16 @@ def _unwrap_scalar(ret: ContainerResult): return deserialize(ret.items[0].result.data, ret.client) +def _unwrap_batched_scalar(ret: ContainerResult, batch_size): + assert len(ret.items) == batch_size + outputs = [] + for item in ret.items: + assert item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS + outputs.append(deserialize(item.result.data, ret.client)) + assert len(outputs) == batch_size + return outputs + + def _unwrap_exception(ret: ContainerResult): assert len(ret.items) == 1 assert ret.items[0].result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE @@ -276,6 +332,17 @@ def _unwrap_exception(ret: ContainerResult): return ret.items[0].result.exception +def _unwrap_batched_exception(ret: ContainerResult, batch_size): + assert len(ret.items) == batch_size + outputs = [] + for item in ret.items: + assert item.result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE + assert "Traceback" in item.result.traceback + outputs.append(item.result.exception) + assert len(outputs) == batch_size + return outputs + + def _unwrap_generator(ret: ContainerResult) -> Tuple[List[Any], Optional[Exception]]: assert len(ret.items) == 1 item = ret.items[0] @@ -1021,6 +1088,96 @@ def test_concurrent_inputs_async_function(servicer): assert function_call_id and function_call_id == outputs[i - 1][2] +def _batched_sync_function_helper(servicer, args_list, expected_outputs): + batch_max_size = 4 + batch_linger_ms = 500 + inputs = _get_inputs_batched(args_list, batch_max_size) + + ret = _run_container( + servicer, + "test.supports.functions", + "batched_function_sync", + inputs=inputs, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, + ) + + outputs = _unwrap_batched_scalar(ret, len(expected_outputs)) + assert outputs == expected_outputs + + +@skip_github_non_linux +def test_batched_sync_function(servicer): + # full batch + _batched_sync_function_helper(servicer, [((10,), {"y": 5}) for _ in range(4)], [2] * 4) + # partial batch + _batched_sync_function_helper(servicer, [((10,), {"y": 5}) for _ in range(2)], [2] * 6) + # kwarg / arg mix + _batched_sync_function_helper( + servicer, [(tuple(), {"x": 10, "y": 5}), ((10, 5), {}), ((10,), {}), (tuple(), {"x": 10})], [2] * 8 + [10] * 2 + ) + + +@skip_github_non_linux +def test_batched_sync_function_inputs_error(servicer): + args_list = [((3,), {"y": 5}) for _ in range(3)] + [(tuple(), {"y": 5})] + with pytest.raises(InvalidError) as err: + batch_max_size = 4 + batch_linger_ms = 500 + inputs = _get_inputs_batched(args_list, batch_max_size) + + _run_container( + servicer, + "test.supports.functions", + "batched_function_sync", + inputs=inputs, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, + ) + assert "Batched function batched_function_sync missing required positional argument" in str(err) + + +@skip_github_non_linux +def test_batched_sync_function_generic_error(servicer): + args_list = [((10,), {"y": 0}) for _ in range(4)] + batch_max_size = 4 + batch_linger_ms = 500 + inputs = _get_inputs_batched(args_list, batch_max_size) + + ret = _run_container( + servicer, + "test.supports.functions", + "batched_function_sync", + inputs=inputs, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, + ) + outputs = _unwrap_batched_exception(ret, batch_max_size) + for output in outputs: + assert output == "ZeroDivisionError('division by zero')" + + +@skip_github_non_linux +def test_batched_async_function(servicer): + batch_max_size = 4 + batch_linger_ms = 500 + args = [((10,), {"y": 5}) for _ in range(4)] + inputs = _get_inputs_batched(args, batch_max_size) + + ret = _run_container( + servicer, + "test.supports.functions", + "batched_function_async", + inputs=inputs, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, + ) + + outputs = _unwrap_batched_scalar(ret, batch_max_size) + for output in outputs: + assert output == 2 + + @skip_github_non_linux def test_unassociated_function(servicer): ret = _run_container(servicer, "test.supports.functions", "unassociated_function") diff --git a/test/supports/functions.py b/test/supports/functions.py index 4fcc9618d..e5d04bcde 100644 --- a/test/supports/functions.py +++ b/test/supports/functions.py @@ -9,6 +9,7 @@ App, Sandbox, asgi_app, + batch, build, current_function_call_id, current_input_id, @@ -330,6 +331,25 @@ async def sleep_700_async(x): return x * x, current_input_id(), current_function_call_id() +@app.function() +@batch(batch_max_size=4, batch_linger_ms=500) +def batched_function_sync(x, y=1): + outputs = [] + for x_i, y_i in zip(x, y): + outputs.append(x_i / y_i) + return outputs + + +@app.function() +@batch(batch_max_size=4, batch_linger_ms=500) +async def batched_function_async(x, y=1): + outputs = [] + for x_i, y_i in zip(x, y): + outputs.append(x_i / y_i) + await asyncio.sleep(0.1) + return outputs + + def unassociated_function(x): return 100 - x From bd80e7d22328080a6737fa711ca8966a7513d8b4 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Fri, 26 Jul 2024 20:40:30 +0000 Subject: [PATCH 2/8] enable class method to batch --- modal/_container_entrypoint.py | 17 +++++++-------- modal/app.py | 23 +++++++++++++++++++- modal/cls.py | 4 ---- modal/functions.py | 1 - modal/partial_function.py | 38 +++++++++++++++++----------------- test/container_test.py | 4 ++-- 6 files changed, 50 insertions(+), 37 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 5525aa0b3..91c9d7c68 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -139,13 +139,6 @@ def get_func_signature_info(func, ignore_self=False): signature_params = list(inspect.signature(func).parameters.values()) function_name = func.__name__ - if ignore_self: - if len(signature_params) == 0: - raise ValueError( - "Methods must take a 'self' argument, but the " f"method '{func.__name__}' does not have one." - ) - signature_params = signature_params[1:] - param_names_and_defaults = [(param.name, param.default) for param in signature_params] return FuncSignatureInfo(function_name, param_names_and_defaults) @@ -224,7 +217,7 @@ def get_finalized_functions( is_async=is_async, is_generator=is_generator, data_format=api_pb2.DATA_FORMAT_PICKLE, - signature_info=get_func_signature_info(bound_func, ignore_self=True), + signature_info=get_func_signature_info(bound_func), ) else: web_callable = construct_webhook_callable(bound_func, webhook_config, container_io_manager) @@ -233,7 +226,7 @@ def get_finalized_functions( is_async=True, is_generator=True, data_format=api_pb2.DATA_FORMAT_ASGI, - signature_info=get_func_signature_info(bound_func, ignore_self=True), + signature_info=get_func_signature_info(bound_func), ) finalized_functions[method_name] = finalized_function return finalized_functions @@ -367,7 +360,7 @@ def _aggregate_ids_and_args( for local_input in local_inputs ] ) - args_and_kwargs_dict = [{} for _ in input_ids] + args_and_kwargs_dict: List[Dict[str, Any]] = [{} for _ in input_ids] for i, (args, kwargs) in enumerate(zip(args_list, kwargs_list)): for j, arg in enumerate(args): if j < len(signature_info.params_and_defaults): @@ -500,6 +493,10 @@ def run_input_sync( if finalized_function.is_generator: if not inspect.isgenerator(res): raise InvalidError(f"Generator function returned value of type {type(res)}") + if isinstance(function_call_ids, list): + raise InvalidError( + f"Batched function {finalized_function.signature_info.func_name} cannot return generators." + ) # Send up to this many outputs at a time. generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024) diff --git a/modal/app.py b/modal/app.py index d0c69f76b..8676bee19 100644 --- a/modal/app.py +++ b/modal/app.py @@ -33,7 +33,12 @@ from .mount import _Mount from .network_file_system import _NetworkFileSystem from .object import _Object -from .partial_function import _find_callables_for_cls, _PartialFunction, _PartialFunctionFlags +from .partial_function import ( + _find_callables_for_cls, + _find_partial_methods_for_user_cls, + _PartialFunction, + _PartialFunctionFlags, +) from .proxy import _Proxy from .retries import Retries from .runner import _run_app @@ -774,6 +779,20 @@ def wrapper(user_cls: CLS_T) -> _Cls: raise InvalidError("`region` and `_experimental_scheduler_placement` cannot be used together") scheduler_placement = SchedulerPlacement(region=region) + batch_functions = _find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.BATCH) + if batch_functions: + if ( + len(batch_functions) > 1 + or len(_find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.FUNCTION)) > 1 + ): + raise InvalidError("A class with batch functions cannot have other modal methods.") + batch_function = next(iter(batch_functions.values())) + batch_max_size = batch_function.batch_max_size + batch_linger_ms = batch_function.batch_linger_ms + else: + batch_max_size = None + batch_linger_ms = None + cls_func = _Function.from_args( info, app=self, @@ -791,6 +810,8 @@ def wrapper(user_cls: CLS_T) -> _Cls: retries=retries, concurrency_limit=concurrency_limit, allow_concurrent_inputs=allow_concurrent_inputs, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, container_idle_timeout=container_idle_timeout, timeout=timeout, cpu=cpu, diff --git a/modal/cls.py b/modal/cls.py index 13c05cb6f..52e025e17 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -359,8 +359,6 @@ def with_options( timeout: Optional[int] = None, concurrency_limit: Optional[int] = None, allow_concurrent_inputs: Optional[int] = None, - batch_max_size: Optional[int] = None, - batch_linger_ms: Optional[int] = None, container_idle_timeout: Optional[int] = None, allow_background_volume_commits: Optional[bool] = None, ) -> "_Cls": @@ -406,8 +404,6 @@ def with_options( replace_volume_mounts=replace_volume_mounts, volume_mounts=volume_mounts, allow_concurrent_inputs=allow_concurrent_inputs, - batch_max_size=batch_max_size, - batch_linger_ms=batch_linger_ms, ) return cls diff --git a/modal/functions.py b/modal/functions.py index 6bf056d26..602165229 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -519,7 +519,6 @@ def from_args( assert info.user_cls assert not webhook_config assert not schedule - assert batch_max_size is None if secret is not None: deprecation_error( diff --git a/modal/partial_function.py b/modal/partial_function.py index f2a60785a..949e4045c 100644 --- a/modal/partial_function.py +++ b/modal/partial_function.py @@ -107,17 +107,6 @@ def add_flags(self, flags) -> "_PartialFunction": batch_linger_ms=self.batch_linger_ms, ) - def set_batch_params(self, batch_max_size, batch_linger_ms) -> "_PartialFunction": - self.wrapped = True - return _PartialFunction( - raw_f=self.raw_f, - flags=(self.flags | _PartialFunctionFlags.BATCH), - webhook_config=self.webhook_config, - keep_warm=self.keep_warm, - batch_max_size=batch_max_size, - batch_linger_ms=batch_linger_ms, - ) - PartialFunction = synchronize_api(_PartialFunction) @@ -213,12 +202,18 @@ def f(self): def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction: nonlocal is_generator - if isinstance(raw_f, _PartialFunction) and (raw_f.webhook_config or (raw_f.batch_max_size is not None)): + if isinstance(raw_f, _PartialFunction) and (raw_f.webhook_config): raw_f.wrapped = True # suppress later warning raise InvalidError( "Web endpoints on classes should not be wrapped by `@method`. " "Suggestion: remove the `@method` decorator." ) + if isinstance(raw_f, _PartialFunction) and (raw_f.batch_max_size is not None): + raw_f.wrapped = True # suppress later warning + raise InvalidError( + "Batched function on classes should not be wrapped by `@method`. " + "Suggestion: remove the `@method` decorator." + ) if is_generator is None: is_generator = inspect.isgeneratorfunction(raw_f) or inspect.isasyncgenfunction(raw_f) return _PartialFunction(raw_f, _PartialFunctionFlags.FUNCTION, is_generator=is_generator, keep_warm=keep_warm) @@ -594,14 +589,19 @@ def _batch( if batch_linger_ms > MAX_BATCH_LINGER_MS: raise InvalidError(f"batch_linger_ms must be less than or equal to {MAX_BATCH_LINGER_MS}.") - def wrapper(f: Union[Callable[[Any], Any], _PartialFunction]) -> _PartialFunction: - if isinstance(f, _PartialFunction): - _disallow_wrapping_method(f, "batch") - return f.set_batch_params(batch_max_size, batch_linger_ms) - else: - return _PartialFunction( - f, _PartialFunctionFlags.BATCH, batch_max_size=batch_max_size, batch_linger_ms=batch_linger_ms + def wrapper(raw_f: Union[Callable[[Any], Any], _PartialFunction]) -> _PartialFunction: + if isinstance(raw_f, _Function): + raw_f = raw_f.get_raw_f() + raise InvalidError( + f"Applying decorators for {raw_f} in the wrong order!\nUsage:\n\n" + "@app.function()\n@modal.batch()\ndef batched_function():\n ..." ) + return _PartialFunction( + raw_f, + _PartialFunctionFlags.FUNCTION | _PartialFunctionFlags.BATCH, + batch_max_size=batch_max_size, + batch_linger_ms=batch_linger_ms, + ) return wrapper diff --git a/test/container_test.py b/test/container_test.py index 03661a3d0..c60901056 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -84,7 +84,7 @@ def _get_inputs_batched( *([api_pb2.FunctionGetInputsItem(kill_switch=True)] if kill_switch else []), ] response_list = [] - current_batch = [] + current_batch: List[Any] = [] while inputs: input = inputs.pop(0) if input.kill_switch: @@ -1120,7 +1120,7 @@ def test_batched_sync_function(servicer): @skip_github_non_linux def test_batched_sync_function_inputs_error(servicer): - args_list = [((3,), {"y": 5}) for _ in range(3)] + [(tuple(), {"y": 5})] + args_list: List[Any] = [((3,), {"y": 5}) for _ in range(3)] + [(tuple(), {"y": 5})] with pytest.raises(InvalidError) as err: batch_max_size = 4 batch_linger_ms = 500 From 2b8906b540691c63fa1f0003d716f4a5b27007f2 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Mon, 29 Jul 2024 12:45:52 +0000 Subject: [PATCH 3/8] fix current_id and deserialization latency --- modal/_container_entrypoint.py | 11 ++--------- modal/_container_io_manager.py | 13 +++++++------ modal/execution_context.py | 13 ++++++++++--- modal/functions.py | 6 +++--- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 91c9d7c68..df9cce654 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -354,12 +354,7 @@ def _aggregate_ids_and_args( [param for param in signature_info.params_and_defaults if param[1] == inspect.Parameter.empty] ) num_args = len(signature_info.params_and_defaults) - args_list, kwargs_list = zip( - *[ - container_io_manager.deserialize(local_input.input_args) if local_input.input_args else ((), {}) - for local_input in local_inputs - ] - ) + args_list, kwargs_list = zip(*[(local_input.args, local_input.kwargs) for local_input in local_inputs]) args_and_kwargs_dict: List[Dict[str, Any]] = [{} for _ in input_ids] for i, (args, kwargs) in enumerate(zip(args_list, kwargs_list)): for j, arg in enumerate(args): @@ -407,9 +402,7 @@ def _aggregate_ids_and_args( else: input_id = local_inputs.input_id function_call_id = local_inputs.function_call_id - args, kwargs = ( - container_io_manager.deserialize(local_inputs.input_args) if local_inputs.input_args else ((), {}) - ) + args, kwargs = local_inputs.args, local_inputs.kwargs return input_id, function_call_id, args, kwargs diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 3d52c1f59..0bb8df52f 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -9,7 +9,7 @@ import traceback from dataclasses import dataclass from pathlib import Path -from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, List, Optional, Set, Tuple, Union +from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union from google.protobuf.empty_pb2 import Empty from google.protobuf.message import Message @@ -47,7 +47,8 @@ class LocalInput: input_id: str function_call_id: str method_name: str - input_args: Any + args: Tuple[Any, ...] + kwargs: Dict[str, Any] class _ContainerIOManager: @@ -459,17 +460,17 @@ async def run_inputs_outputs( if batch_max_size == 0: async for input_id, function_call_id, input_pb in self._generate_inputs(): + args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) self.current_input_id, self.current_input_started_at = (input_id, time.time()) - yield LocalInput(input_id, function_call_id, input_pb.method_name, input_pb.args) + yield LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs) self.current_input_id, self.current_input_started_at = (None, None) else: async for inputs_list in self._generate_inputs(): local_inputs_list = [] for input_id, function_call_id, input_pb in inputs_list: + args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) self.current_input_id, self.current_input_started_at = (input_id, time.time()) - local_inputs_list.append( - LocalInput(input_id, function_call_id, input_pb.method_name, input_pb.args) - ) + local_inputs_list.append(LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs)) self.current_input_id, self.current_input_started_at = (None, None) yield local_inputs_list diff --git a/modal/execution_context.py b/modal/execution_context.py index e45cab899..bf1bd399d 100644 --- a/modal/execution_context.py +++ b/modal/execution_context.py @@ -32,7 +32,7 @@ async def _interact() -> None: interact = synchronize_api(_interact) -def current_input_id() -> Optional[Union[str, List[str]]]: +def current_input_id() -> Optional[str]: """Returns the input ID for the current input. Can only be called from Modal function (i.e. in a container context). @@ -73,8 +73,15 @@ def process_stuff(): def _set_current_context_ids( input_ids: Union[str, List[str]], function_call_ids: Union[str, List[str]] ) -> Callable[[], None]: - input_token = _current_input_id.set(input_ids) - function_call_token = _current_function_call_id.set(function_call_ids) + if isinstance(input_ids, list): + assert len(input_ids) == len(function_call_ids) and len(input_ids) > 0 + input_id = input_ids[0] + function_call_id = function_call_ids[0] + else: + input_id = input_ids + function_call_id = function_call_ids + input_token = _current_input_id.set(input_id) + function_call_token = _current_function_call_id.set(function_call_id) def _reset_current_context_ids(): _current_input_id.reset(input_token) diff --git a/modal/functions.py b/modal/functions.py index 602165229..dda5e4eb0 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -344,8 +344,8 @@ async def _load(method_bound_function: "_Function", resolver: Resolver, existing is_method=True, use_function_id=class_service_function.object_id, use_method_name=method_name, - batch_max_size=partial_function.batch_max_size, - batch_linger_ms=partial_function.batch_linger_ms, + batch_max_size=partial_function.batch_max_size or 0, + batch_linger_ms=partial_function.batch_linger_ms or 0, ) assert resolver.app_id request = api_pb2.FunctionCreateRequest( @@ -487,7 +487,7 @@ def from_args( concurrency_limit: Optional[int] = None, allow_concurrent_inputs: Optional[int] = None, batch_max_size: Optional[int] = None, - batch_linger_ms: Optional[float] = None, + batch_linger_ms: Optional[int] = None, container_idle_timeout: Optional[int] = None, cpu: Optional[float] = None, keep_warm: Optional[int] = None, # keep_warm=True is equivalent to keep_warm=1 From c6d8ff00316254abeb9f4b870f5ee69edbeab936 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Wed, 31 Jul 2024 16:51:07 +0000 Subject: [PATCH 4/8] cleanup and unit test --- modal/_container_entrypoint.py | 159 ++++++++++++++++----------------- modal/_container_io_manager.py | 50 +++++++---- modal/app.py | 11 +-- modal/functions.py | 2 +- modal/partial_function.py | 14 +-- test/cls_test.py | 33 +++++++ test/container_test.py | 127 +++++++++++++------------- test/decorator_test.py | 15 +++- test/function_test.py | 25 +++++- test/supports/functions.py | 18 +++- 10 files changed, 271 insertions(+), 183 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index df9cce654..6332a35a0 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -107,7 +107,6 @@ class FinalizedFunction: is_async: bool is_generator: bool data_format: int # api_pb2.DataFormat - signature_info: "FuncSignatureInfo" class Service(metaclass=ABCMeta): @@ -129,20 +128,6 @@ def get_finalized_functions( ... -@dataclass -class FuncSignatureInfo: - func_name: str - params_and_defaults: List[Tuple[str, Any]] - - -def get_func_signature_info(func, ignore_self=False): - signature_params = list(inspect.signature(func).parameters.values()) - function_name = func.__name__ - - param_names_and_defaults = [(param.name, param.default) for param in signature_params] - return FuncSignatureInfo(function_name, param_names_and_defaults) - - @dataclass class ImportedFunction(Service): user_cls_instance: Any @@ -168,7 +153,6 @@ def get_finalized_functions( is_async=is_async, is_generator=is_generator, data_format=api_pb2.DATA_FORMAT_PICKLE, - signature_info=get_func_signature_info(self._user_defined_callable), ) } @@ -182,7 +166,6 @@ def get_finalized_functions( is_async=True, is_generator=True, data_format=api_pb2.DATA_FORMAT_ASGI, - signature_info=get_func_signature_info(self._user_defined_callable), ) } @@ -217,7 +200,6 @@ def get_finalized_functions( is_async=is_async, is_generator=is_generator, data_format=api_pb2.DATA_FORMAT_PICKLE, - signature_info=get_func_signature_info(bound_func), ) else: web_callable = construct_webhook_callable(bound_func, webhook_config, container_io_manager) @@ -226,7 +208,6 @@ def get_finalized_functions( is_async=True, is_generator=True, data_format=api_pb2.DATA_FORMAT_ASGI, - signature_info=get_func_signature_info(bound_func), ) finalized_functions[method_name] = finalized_function return finalized_functions @@ -342,68 +323,44 @@ def _sigint_handler(): self.loop.remove_signal_handler(signal.SIGINT) -def _aggregate_ids_and_args( +def _aggregate_args_and_kwargs( local_inputs: Union[LocalInput, List[LocalInput]], - container_io_manager: "modal._container_io_manager.ContainerIOManager", - signature_info: FuncSignatureInfo, -) -> Tuple[Union[str, List[str]], Union[str, List[str]], List[Any], Dict[str, Any]]: + callable: Callable[..., Any], +) -> Tuple[List[Any], Dict[str, Any]]: if isinstance(local_inputs, list): - input_ids = [local_input.input_id for local_input in local_inputs] - function_call_ids = [local_input.function_call_id for local_input in local_inputs] - num_required_args = len( - [param for param in signature_info.params_and_defaults if param[1] == inspect.Parameter.empty] - ) - num_args = len(signature_info.params_and_defaults) - args_list, kwargs_list = zip(*[(local_input.args, local_input.kwargs) for local_input in local_inputs]) - args_and_kwargs_dict: List[Dict[str, Any]] = [{} for _ in input_ids] - for i, (args, kwargs) in enumerate(zip(args_list, kwargs_list)): - for j, arg in enumerate(args): - if j < len(signature_info.params_and_defaults): - args_and_kwargs_dict[i][signature_info.params_and_defaults[j][0]] = arg - else: - raise ( - InvalidError( - f"Batched function {signature_info.func_name} takes {num_required_args} \ - positional arguments but {len(args)} were given" - ) - if num_args == num_required_args - else InvalidError( - f"Batched function {signature_info.func_name} takes from {num_required_args} to {num_args} \ - positional arguments but {len(args)} were given" - ) + param_names = list(inspect.signature(callable).parameters.keys()) + for param in inspect.signature(callable).parameters.values(): + if param.default is not inspect.Parameter.empty: + raise InvalidError(f"Modal batch function {callable.__name__} does not accept default arguments.") + + args_by_inputs = [{} for _ in range(len(local_inputs))] + for i, local_input in enumerate(local_inputs): + params_len = len(local_input.args) + len(local_input.kwargs) + if params_len != len(param_names): + raise InvalidError( + f"Modal batch function {callable.__name__} takes {len(param_names)} positional arguments, \ + but one call has {params_len}." + ) + for j, arg in enumerate(local_input.args): + args_by_inputs[i][param_names[j]] = arg + for k, v in local_input.kwargs.items(): + if k not in param_names: + raise InvalidError( + f"Modal batch function {callable.__name__} got an unexpected keyword argument {k} in one call." ) - for k, v in kwargs.items(): - if k in [param[0] for param in signature_info.params_and_defaults]: - if k in args_and_kwargs_dict[i]: - raise InvalidError( - f"Batched function {signature_info.func_name} got multiple values for argument {k}" - ) - args_and_kwargs_dict[i][k] = v - else: + if k in args_by_inputs[i]: raise InvalidError( - f"Batched function {signature_info.func_name} got an unexpected keyword argument: {k}" + f"Modal batch function {callable.__name__} got multiple values for argument {k} in one call." ) - - formatted_args = [] - for arg_name, _ in signature_info.params_and_defaults[:num_required_args]: - if any(arg_name not in args_and_kwargs_dict[i] for i in range(len(input_ids))): - raise InvalidError( - f"Batched function {signature_info.func_name} missing required positional argument {arg_name}" - ) - formatted_args.append([args_and_kwargs_dict[i][arg_name] for i in range(len(input_ids))]) + args_by_inputs[i][k] = v formatted_kwargs = { - kwarg_name: [args_and_kwargs_dict[j].get(kwarg_name, default) for j in range(len(input_ids))] - for kwarg_name, default in signature_info.params_and_defaults[num_required_args:] + param_name: [args_by_inputs[i][param_name] for i in range(len(local_inputs))] for param_name in param_names } - - return input_ids, function_call_ids, formatted_args, formatted_kwargs + return tuple(), formatted_kwargs else: - input_id = local_inputs.input_id - function_call_id = local_inputs.function_call_id - args, kwargs = local_inputs.args, local_inputs.kwargs - return input_id, function_call_id, args, kwargs + return local_inputs.args, local_inputs.kwargs def call_function( @@ -420,11 +377,21 @@ async def run_input_async( container_io_manager: "modal._container_io_manager.ContainerIOManager", ) -> None: started_at = time.time() - input_ids, function_call_ids, args, kwargs = _aggregate_ids_and_args( - local_inputs, container_io_manager, finalized_function.signature_info + input_ids = ( + local_inputs.input_id + if isinstance(local_inputs, LocalInput) + else [local_input.input_id for local_input in local_inputs] ) + function_call_ids = ( + local_inputs.function_call_id + if isinstance(local_inputs, LocalInput) + else [local_input.function_call_id for local_input in local_inputs] + ) + args, kwargs = _aggregate_args_and_kwargs(local_inputs, finalized_function.callable) reset_context = _set_current_context_ids(input_ids, function_call_ids) - async with container_io_manager.handle_input_exception.aio(input_ids, started_at): + async with container_io_manager.handle_input_exception.aio( + input_ids, started_at, finalized_function.callable.__name__ + ): logger.debug(f"Starting input {input_ids} (async)") res = finalized_function.callable(*args, **kwargs) logger.debug(f"Finished input {input_ids} (async)") @@ -434,7 +401,9 @@ async def run_input_async( if not inspect.isasyncgen(res): raise InvalidError(f"Async generator function returned value of type {type(res)}") if isinstance(input_ids, list) or isinstance(function_call_ids, list): - raise InvalidError("Batched functions cannot return generators.") + raise InvalidError( + f"Batch function {finalized_function.callable.__name__} cannot return a generator." + ) # Send up to this many outputs at a time. generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024) @@ -455,7 +424,11 @@ async def run_input_async( await generator_output_task # Wait to finish sending generator outputs. message = api_pb2.GeneratorDone(items_total=item_count) await container_io_manager.push_output.aio( - input_ids, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE + input_ids, + started_at, + finalized_function.callable.__name__, + message, + api_pb2.DATA_FORMAT_GENERATOR_DONE, ) else: if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res): @@ -464,7 +437,9 @@ async def run_input_async( " You might need to use @app.function(..., is_generator=True)." ) value = await res - await container_io_manager.push_output.aio(input_ids, started_at, value, finalized_function.data_format) + await container_io_manager.push_output.aio( + input_ids, started_at, finalized_function.callable.__name__, value, finalized_function.data_format + ) reset_context() def run_input_sync( @@ -473,11 +448,19 @@ def run_input_sync( container_io_manager: "modal._container_io_manager.ContainerIOManager", ) -> None: started_at = time.time() - input_ids, function_call_ids, args, kwargs = _aggregate_ids_and_args( - local_inputs, container_io_manager, finalized_function.signature_info + input_ids = ( + local_inputs.input_id + if isinstance(local_inputs, LocalInput) + else [local_input.input_id for local_input in local_inputs] + ) + function_call_ids = ( + local_inputs.function_call_id + if isinstance(local_inputs, LocalInput) + else [local_input.function_call_id for local_input in local_inputs] ) + args, kwargs = _aggregate_args_and_kwargs(local_inputs, finalized_function.callable) reset_context = _set_current_context_ids(input_ids, function_call_ids) - with container_io_manager.handle_input_exception(input_ids, started_at): + with container_io_manager.handle_input_exception(input_ids, started_at, finalized_function.callable.__name__): logger.debug(f"Starting input {input_ids} (sync)") res = finalized_function.callable(*args, **kwargs) logger.debug(f"Finished input {input_ids} (sync)") @@ -488,7 +471,7 @@ def run_input_sync( raise InvalidError(f"Generator function returned value of type {type(res)}") if isinstance(function_call_ids, list): raise InvalidError( - f"Batched function {finalized_function.signature_info.func_name} cannot return generators." + f"Batch function {finalized_function.callable.__name__} cannot return generators." ) # Send up to this many outputs at a time. @@ -508,20 +491,28 @@ def run_input_sync( container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL) generator_output_task.result() # Wait to finish sending generator outputs. message = api_pb2.GeneratorDone(items_total=item_count) - container_io_manager.push_output(input_ids, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE) + container_io_manager.push_output( + input_ids, + started_at, + finalized_function.callable.__name__, + message, + api_pb2.DATA_FORMAT_GENERATOR_DONE, + ) else: if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res): raise InvalidError( f"Sync (non-generator) function return value of type {type(res)}." " You might need to use @app.function(..., is_generator=True)." ) - container_io_manager.push_output(input_ids, started_at, res, finalized_function.data_format) + container_io_manager.push_output( + input_ids, started_at, finalized_function.callable.__name__, res, finalized_function.data_format + ) reset_context() def _get_finalized_functions(local_inputs: Union[LocalInput, List[LocalInput]]) -> FinalizedFunction: if isinstance(local_inputs, list): assert len(local_inputs) > 0 - assert all(local_input.method_name == local_inputs[0].method_name for local_input in local_inputs) + assert len(set([local_input.method_name for local_input in local_inputs])) == 1 return finalized_functions[local_inputs[0].method_name] else: return finalized_functions[local_inputs.method_name] diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 0bb8df52f..115430a9b 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -405,7 +405,6 @@ async def _generate_inputs( if item.input.final_input or self.function_def.max_inputs == 1: eof_received = True break - else: assert len(response.inputs) <= request.batch_max_size @@ -427,7 +426,7 @@ async def _generate_inputs( inputs_list.append((item.input_id, item.function_call_id, input_pb)) if item.input.final_input: eof_received = True - logger.error("Final input not expected in batched input stream") + logger.error("Final input not expected in batch input stream") break if not eof_received: @@ -471,28 +470,41 @@ async def run_inputs_outputs( args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) self.current_input_id, self.current_input_started_at = (input_id, time.time()) local_inputs_list.append(LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs)) - self.current_input_id, self.current_input_started_at = (None, None) yield local_inputs_list + self.current_input_id, self.current_input_started_at = (None, None) # collect all active input slots, meaning all inputs have wrapped up. for _ in range(input_concurrency): await self._semaphore.acquire() async def _push_output( - self, input_ids: Union[str, List[str]], started_at: float, data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, **kwargs + self, + input_ids: Union[str, List[str]], + started_at: float, + function_name: str, + data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, + **kwargs, ): outputs = [] if isinstance(input_ids, list): - data_list = [] + formatted_data = None if "data" in kwargs and kwargs["data"]: # split the list of data in kwargs to respective input_ids - data_list = self.deserialize_data_format(kwargs.pop("data"), data_format) - assert isinstance(data_list, list), "Output of batched function must be a list" - assert len(data_list) == len( - input_ids - ), "Output of batched function must be a list of the same length as its list of inputs." + # report error for every input_id in batch call + if "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: + formatted_data = [kwargs.pop("data")] * len(input_ids) + else: + data = self.deserialize_data_format(kwargs.pop("data"), data_format) + if not isinstance(data, list): + raise InvalidError(f"Output of batch function {function_name} must be a list.") + if len(data) != len(input_ids): + raise InvalidError( + f"Output of batch function {function_name} must be \ + a list of the same length as its list of inputs." + ) + formatted_data = [self.serialize_data_format(d, data_format) for d in data] for i, input_id in enumerate(input_ids): - data = self.serialize_data_format(data_list[i], data_format) if data_list else None + data = formatted_data[i] if formatted_data else None result = ( ( # upload data to S3 if too big. @@ -592,7 +604,7 @@ async def handle_user_exception(self) -> AsyncGenerator[None, None]: @asynccontextmanager async def handle_input_exception( - self, input_ids: Union[str, List[str]], started_at: float + self, input_ids: Union[str, List[str]], started_at: float, function_name: str ) -> AsyncGenerator[None, None]: """Handle an exception while processing a function input.""" try: @@ -611,6 +623,10 @@ async def handle_input_exception( logger.warning(f"The current input ({input_ids=}) was cancelled by a user request") await self.complete_call(started_at) return + except InvalidError as exc: + # If there is an error in batch function output, we need to explicitly raise it + if "Output of batch function" in exc.args[0]: + raise except BaseException as exc: # print exception so it's logged traceback.print_exc() @@ -634,13 +650,10 @@ async def handle_input_exception( await self._push_output( input_ids, started_at=started_at, + function_name=function_name, data_format=api_pb2.DATA_FORMAT_PICKLE, status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, - data=( - self.serialize_exception([exc] for _ in input_ids) - if isinstance(input_ids, list) - else self.serialize_exception(exc) - ), + data=self.serialize_exception(exc), exception=repr_exc, traceback=traceback.format_exc(), serialized_tb=serialized_tb, @@ -654,10 +667,11 @@ async def complete_call(self, started_at): self._semaphore.release() @synchronizer.no_io_translation - async def push_output(self, input_id, started_at: float, data: Any, data_format: int) -> None: + async def push_output(self, input_id, started_at: float, function_name: str, data: Any, data_format: int) -> None: await self._push_output( input_id, started_at=started_at, + function_name=function_name, data_format=data_format, status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS, data=self.serialize_data_format(data, data_format), diff --git a/modal/app.py b/modal/app.py index 8676bee19..f6efca1e2 100644 --- a/modal/app.py +++ b/modal/app.py @@ -781,11 +781,12 @@ def wrapper(user_cls: CLS_T) -> _Cls: batch_functions = _find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.BATCH) if batch_functions: - if ( - len(batch_functions) > 1 - or len(_find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.FUNCTION)) > 1 - ): - raise InvalidError("A class with batch functions cannot have other modal methods.") + if len(batch_functions) > 1: + raise InvalidError(f"Modal class {user_cls.__name__} can only have one batch function.") + if len(_find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.FUNCTION)) > 1: + raise InvalidError( + f"Modal class {user_cls.__name__} with a modal batch function cannot have other modal methods." + ) batch_function = next(iter(batch_functions.values())) batch_max_size = batch_function.batch_max_size batch_linger_ms = batch_function.batch_linger_ms diff --git a/modal/functions.py b/modal/functions.py index dda5e4eb0..1d5dca34c 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -657,7 +657,7 @@ def from_args( else: raise InvalidError("Webhooks cannot be generators") if is_generator and batch_max_size: - raise InvalidError("Batched functions cannot be generators") + raise InvalidError(f"Batch functions {info.raw_f.__name__} cannot return a generator") if container_idle_timeout is not None and container_idle_timeout <= 0: raise InvalidError("`container_idle_timeout` must be > 0") diff --git a/modal/partial_function.py b/modal/partial_function.py index 949e4045c..2b1f65ff4 100644 --- a/modal/partial_function.py +++ b/modal/partial_function.py @@ -20,8 +20,8 @@ from .exception import InvalidError, deprecation_error, deprecation_warning from .functions import _Function -MAX_BATCH_SIZE = 49 -MAX_BATCH_LINGER_MS = 12000 +MAX_BATCH_SIZE = 1000 +MAX_BATCH_LINGER_MS = 10 * 60 * 1000 # 10 minutes class _PartialFunctionFlags(enum.IntFlag): @@ -211,7 +211,7 @@ def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction: if isinstance(raw_f, _PartialFunction) and (raw_f.batch_max_size is not None): raw_f.wrapped = True # suppress later warning raise InvalidError( - "Batched function on classes should not be wrapped by `@method`. " + "Batch function on classes should not be wrapped by `@method`. " "Suggestion: remove the `@method` decorator." ) if is_generator is None: @@ -582,12 +582,12 @@ def _batch( raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@batch()`.") if batch_max_size < 1: raise InvalidError("batch_max_size must be a positive integer.") - if batch_max_size > MAX_BATCH_SIZE: - raise InvalidError(f"batch_max_size must be less than or equal to {MAX_BATCH_SIZE}.") + if batch_max_size >= MAX_BATCH_SIZE: + raise InvalidError(f"batch_max_size must be less than {MAX_BATCH_SIZE}.") if batch_linger_ms < 0: raise InvalidError("batch_linger_ms must be a non-negative integer.") - if batch_linger_ms > MAX_BATCH_LINGER_MS: - raise InvalidError(f"batch_linger_ms must be less than or equal to {MAX_BATCH_LINGER_MS}.") + if batch_linger_ms >= MAX_BATCH_LINGER_MS: + raise InvalidError(f"batch_linger_ms must be less than {MAX_BATCH_LINGER_MS}.") def wrapper(raw_f: Union[Callable[[Any], Any], _PartialFunction]) -> _PartialFunction: if isinstance(raw_f, _Function): diff --git a/test/cls_test.py b/test/cls_test.py index fe38ca804..0faf4f86f 100644 --- a/test/cls_test.py +++ b/test/cls_test.py @@ -682,6 +682,7 @@ def __exit__(self, exc_type, exc, tb): _find_callables_for_obj(obj, _PartialFunctionFlags.EXIT) with pytest.raises(DeprecationError, match="Support for decorating parameterized methods with `@exit`"): + class ClsWithDeprecatedSyncExitMethod: @exit() def my_exit(self, exc_type, exc, tb): @@ -710,6 +711,7 @@ async def __aexit__(self, exc_type, exc, tb): _find_callables_for_obj(obj, _PartialFunctionFlags.EXIT) with pytest.raises(DeprecationError, match="Support for decorating parameterized methods with `@exit`"): + class ClsWithDeprecatedAsyncExitMethod: @exit() async def my_exit(self, exc_type, exc, tb): @@ -860,3 +862,34 @@ def test_disabled_parameterized_snap_cls(): app.cls(enable_memory_snapshot=True)(ParameterizedClass2) app.cls(enable_memory_snapshot=True)(ParameterizedClass3) + + +app_batch = App() + + +def test_batch_method_duplicate_error(client): + with pytest.raises( + InvalidError, match="Modal class BatchClass_1 with a modal batch function cannot have other modal methods." + ): + + @app_batch.cls(serialized=True) + class BatchClass_1: + @modal.method() + def method(self): + pass + + @modal.batch(batch_max_size=2, batch_linger_ms=0) + def batch_method(self): + pass + + with pytest.raises(InvalidError, match="Modal class BatchClass_2 can only have one batch function."): + + @app_batch.cls(serialized=True) + class BatchClass_2: + @modal.batch(batch_max_size=2, batch_linger_ms=0) + def batch_method_1(self): + pass + + @modal.batch(batch_max_size=2, batch_linger_ms=0) + def batch_method_2(self): + pass diff --git a/test/container_test.py b/test/container_test.py index c60901056..8a305a477 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -64,7 +64,7 @@ def _get_inputs( return [api_pb2.FunctionGetInputsResponse(inputs=[x]) for x in inputs] -def _get_inputs_batched( +def _get_inputs_batch( args_list: List[Tuple[Tuple, Dict]], batch_max_size: int, kill_switch=True, @@ -315,7 +315,7 @@ def _unwrap_scalar(ret: ContainerResult): return deserialize(ret.items[0].result.data, ret.client) -def _unwrap_batched_scalar(ret: ContainerResult, batch_size): +def _unwrap_batch_scalar(ret: ContainerResult, batch_size): assert len(ret.items) == batch_size outputs = [] for item in ret.items: @@ -332,7 +332,7 @@ def _unwrap_exception(ret: ContainerResult): return ret.items[0].result.exception -def _unwrap_batched_exception(ret: ContainerResult, batch_size): +def _unwrap_batch_exception(ret: ContainerResult, batch_size): assert len(ret.items) == batch_size outputs = [] for item in ret.items: @@ -1088,94 +1088,95 @@ def test_concurrent_inputs_async_function(servicer): assert function_call_id and function_call_id == outputs[i - 1][2] -def _batched_sync_function_helper(servicer, args_list, expected_outputs): +def _batch_function_test_helper(batch_func, servicer, args_list, expected_outputs, expected_status="success"): batch_max_size = 4 batch_linger_ms = 500 - inputs = _get_inputs_batched(args_list, batch_max_size) + inputs = _get_inputs_batch(args_list, batch_max_size) ret = _run_container( servicer, "test.supports.functions", - "batched_function_sync", + batch_func, inputs=inputs, batch_max_size=batch_max_size, batch_linger_ms=batch_linger_ms, ) - - outputs = _unwrap_batched_scalar(ret, len(expected_outputs)) + if expected_status == "success": + outputs = _unwrap_batch_scalar(ret, len(expected_outputs)) + else: + outputs = _unwrap_batch_exception(ret, len(expected_outputs)) assert outputs == expected_outputs @skip_github_non_linux -def test_batched_sync_function(servicer): - # full batch - _batched_sync_function_helper(servicer, [((10,), {"y": 5}) for _ in range(4)], [2] * 4) - # partial batch - _batched_sync_function_helper(servicer, [((10,), {"y": 5}) for _ in range(2)], [2] * 6) - # kwarg / arg mix - _batched_sync_function_helper( - servicer, [(tuple(), {"x": 10, "y": 5}), ((10, 5), {}), ((10,), {}), (tuple(), {"x": 10})], [2] * 8 + [10] * 2 - ) +def test_batch_sync_function_full_batch(servicer): + inputs = [((10, 5), {}) for _ in range(4)] + expected_outputs = [2] * 4 + _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_outputs) @skip_github_non_linux -def test_batched_sync_function_inputs_error(servicer): - args_list: List[Any] = [((3,), {"y": 5}) for _ in range(3)] + [(tuple(), {"y": 5})] - with pytest.raises(InvalidError) as err: - batch_max_size = 4 - batch_linger_ms = 500 - inputs = _get_inputs_batched(args_list, batch_max_size) +def test_batch_sync_function_partial_batch(servicer): + inputs = [((10, 5), {}) for _ in range(2)] + expected_outputs = [2] * 2 + _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_outputs) - _run_container( - servicer, - "test.supports.functions", - "batched_function_sync", - inputs=inputs, - batch_max_size=batch_max_size, - batch_linger_ms=batch_linger_ms, - ) - assert "Batched function batched_function_sync missing required positional argument" in str(err) + +def test_batch_sync_function_keyword_args(servicer): + inputs = [((10,), {"y": 5}) for _ in range(4)] + expected_outputs = [2] * 4 + _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_outputs) @skip_github_non_linux -def test_batched_sync_function_generic_error(servicer): - args_list = [((10,), {"y": 0}) for _ in range(4)] - batch_max_size = 4 - batch_linger_ms = 500 - inputs = _get_inputs_batched(args_list, batch_max_size) +def test_batch_sync_function_inputs_outputs_error(servicer): + # argument length does not match + inputs = [((10, 5), {}), ((10, 5, 1), {})] + with pytest.raises(InvalidError) as err: + _batch_function_test_helper("batch_function_sync", servicer, inputs, []) + assert "Modal batch function batch_function_sync takes 2 positional arguments, but one call has 3." in str(err) - ret = _run_container( - servicer, - "test.supports.functions", - "batched_function_sync", - inputs=inputs, - batch_max_size=batch_max_size, - batch_linger_ms=batch_linger_ms, + # Unexpected keyword arg + inputs = [((10, 5), {}), ((10,), {"z": 5})] + with pytest.raises(InvalidError) as err: + _batch_function_test_helper("batch_function_sync", servicer, inputs, []) + assert "Modal batch function batch_function_sync got an unexpected keyword argument z in one call." in str(err) + + # Multiple values with keyword arg + inputs = [((10, 5), {}), ((10,), {"x": 1})] + with pytest.raises(InvalidError) as err: + _batch_function_test_helper("batch_function_sync", servicer, inputs, []) + assert "Modal batch function batch_function_sync got multiple values for argument x in one call." in str(err) + + # output must be list + inputs = [((10, 5), {})] + with pytest.raises(InvalidError) as err: + _batch_function_test_helper("batch_function_outputs_not_list", servicer, inputs, []) + assert "Output of batch function batch_function_outputs_not_list must be a list." in str(err) + + # outputs must match length of inputs + inputs = [((10, 5), {})] + with pytest.raises(InvalidError) as err: + _batch_function_test_helper("batch_function_outputs_wrong_len", servicer, inputs, []) + assert ( + "Output of batch function batch_function_outputs_wrong_len must be \ + a list of the same length as its list of inputs." + in str(err) ) - outputs = _unwrap_batched_exception(ret, batch_max_size) - for output in outputs: - assert output == "ZeroDivisionError('division by zero')" @skip_github_non_linux -def test_batched_async_function(servicer): - batch_max_size = 4 - batch_linger_ms = 500 - args = [((10,), {"y": 5}) for _ in range(4)] - inputs = _get_inputs_batched(args, batch_max_size) +def test_batch_sync_function_generic_error(servicer): + inputs = [((10, 0), {}) for _ in range(4)] + expected_ouputs = ["ZeroDivisionError('division by zero')"] * 4 + _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_ouputs, expected_status="failure") - ret = _run_container( - servicer, - "test.supports.functions", - "batched_function_async", - inputs=inputs, - batch_max_size=batch_max_size, - batch_linger_ms=batch_linger_ms, - ) - outputs = _unwrap_batched_scalar(ret, batch_max_size) - for output in outputs: - assert output == 2 +@skip_github_non_linux +def test_batch_async_function(servicer): + inputs = [((10, 5), {}) for _ in range(4)] + expected_outputs = [2] * 4 + _batch_function_test_helper("batch_function_async", servicer, inputs, expected_outputs) @skip_github_non_linux diff --git a/test/decorator_test.py b/test/decorator_test.py index a837a5202..0db1c7494 100644 --- a/test/decorator_test.py +++ b/test/decorator_test.py @@ -1,7 +1,7 @@ # Copyright Modal Labs 2023 import pytest -from modal import App, asgi_app, method, web_endpoint, wsgi_app +from modal import App, asgi_app, batch, method, web_endpoint, wsgi_app from modal.exception import InvalidError @@ -83,3 +83,16 @@ class Container: @web_endpoint() def generate(self): pass + + +def test_batch_method(): + app = App() + + with pytest.raises(InvalidError, match="remove the `@method`"): + + @app.cls() + class Container: + @method() # type: ignore + @batch(batch_max_size=2, batch_linger_ms=0) + def generate(self): + pass diff --git a/test/function_test.py b/test/function_test.py index 0fee2eaed..d73b351c5 100644 --- a/test/function_test.py +++ b/test/function_test.py @@ -10,7 +10,7 @@ from synchronicity.exceptions import UserCodeException import modal -from modal import App, Image, Mount, NetworkFileSystem, Proxy, web_endpoint +from modal import App, Image, Mount, NetworkFileSystem, Proxy, batch, web_endpoint from modal._utils.async_utils import synchronize_api from modal._vendor import cloudpickle from modal.exception import ExecutionError, InvalidError @@ -802,3 +802,26 @@ def test_function_decorator_on_method(): with pytest.raises(InvalidError, match="@app.cls"): app.function()(X.f) + + +def test_batch_function_invalid_error(): + app = App() + + with pytest.raises(InvalidError, match="must be a positive integer"): + app.function(batch(batch_max_size=0, batch_linger_ms=1))(dummy) + + with pytest.raises(InvalidError, match="must be a non-negative integer"): + app.function(batch(batch_max_size=1, batch_linger_ms=-1))(dummy) + + with pytest.raises(InvalidError, match="must be less than"): + app.function(batch(batch_max_size=1000, batch_linger_ms=1))(dummy) + + with pytest.raises(InvalidError, match="must be less than"): + app.function(batch(batch_max_size=1, batch_linger_ms=10 * 60 * 1000))(dummy) + + with pytest.raises(InvalidError, match="cannot return a generator"): + + @app.function(serialized=True) + @batch(batch_max_size=1, batch_linger_ms=1) + def f(x): + yield [x_i**2 for x_i in x] diff --git a/test/supports/functions.py b/test/supports/functions.py index e5d04bcde..dba51eab8 100644 --- a/test/supports/functions.py +++ b/test/supports/functions.py @@ -3,7 +3,7 @@ import asyncio import time -from typing import List +from typing import List, Tuple from modal import ( App, @@ -333,7 +333,7 @@ async def sleep_700_async(x): @app.function() @batch(batch_max_size=4, batch_linger_ms=500) -def batched_function_sync(x, y=1): +def batch_function_sync(x: Tuple[int], y: Tuple[int]): outputs = [] for x_i, y_i in zip(x, y): outputs.append(x_i / y_i) @@ -342,7 +342,19 @@ def batched_function_sync(x, y=1): @app.function() @batch(batch_max_size=4, batch_linger_ms=500) -async def batched_function_async(x, y=1): +def batch_function_outputs_not_list(x: Tuple[int], y: Tuple[int]): + return str(x) + + +@app.function() +@batch(batch_max_size=4, batch_linger_ms=500) +def batch_function_outputs_wrong_len(x: Tuple[int], y: Tuple[int]): + return list(x) + [0] + + +@app.function() +@batch(batch_max_size=4, batch_linger_ms=500) +async def batch_function_async(x: Tuple[int], y: Tuple[int]): outputs = [] for x_i, y_i in zip(x, y): outputs.append(x_i / y_i) From 96d7f2aed40782b0e4df2046c96ada190ae92440 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Wed, 31 Jul 2024 17:42:33 +0000 Subject: [PATCH 5/8] fix type check --- modal/_container_entrypoint.py | 15 +++++++-------- modal/_container_io_manager.py | 3 +-- test/container_test.py | 15 +++++++-------- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 6332a35a0..e47ea6b5f 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -326,20 +326,19 @@ def _sigint_handler(): def _aggregate_args_and_kwargs( local_inputs: Union[LocalInput, List[LocalInput]], callable: Callable[..., Any], -) -> Tuple[List[Any], Dict[str, Any]]: +) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: if isinstance(local_inputs, list): param_names = list(inspect.signature(callable).parameters.keys()) for param in inspect.signature(callable).parameters.values(): if param.default is not inspect.Parameter.empty: raise InvalidError(f"Modal batch function {callable.__name__} does not accept default arguments.") - args_by_inputs = [{} for _ in range(len(local_inputs))] + args_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(local_inputs))] for i, local_input in enumerate(local_inputs): params_len = len(local_input.args) + len(local_input.kwargs) if params_len != len(param_names): raise InvalidError( - f"Modal batch function {callable.__name__} takes {len(param_names)} positional arguments, \ - but one call has {params_len}." + f"Modal batch function {callable.__name__} takes {len(param_names)} positional arguments, but one call has {params_len}." # noqa ) for j, arg in enumerate(local_input.args): args_by_inputs[i][param_names[j]] = arg @@ -377,12 +376,12 @@ async def run_input_async( container_io_manager: "modal._container_io_manager.ContainerIOManager", ) -> None: started_at = time.time() - input_ids = ( + input_ids: Union[str, List[str]] = ( local_inputs.input_id if isinstance(local_inputs, LocalInput) else [local_input.input_id for local_input in local_inputs] ) - function_call_ids = ( + function_call_ids: Union[str, List[str]] = ( local_inputs.function_call_id if isinstance(local_inputs, LocalInput) else [local_input.function_call_id for local_input in local_inputs] @@ -448,12 +447,12 @@ def run_input_sync( container_io_manager: "modal._container_io_manager.ContainerIOManager", ) -> None: started_at = time.time() - input_ids = ( + input_ids: Union[str, List[str]] = ( local_inputs.input_id if isinstance(local_inputs, LocalInput) else [local_input.input_id for local_input in local_inputs] ) - function_call_ids = ( + function_call_ids: Union[str, List[str]] = ( local_inputs.function_call_id if isinstance(local_inputs, LocalInput) else [local_input.function_call_id for local_input in local_inputs] diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 115430a9b..65f9c9d44 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -499,8 +499,7 @@ async def _push_output( raise InvalidError(f"Output of batch function {function_name} must be a list.") if len(data) != len(input_ids): raise InvalidError( - f"Output of batch function {function_name} must be \ - a list of the same length as its list of inputs." + f"Output of batch function {function_name} must be a list of the same length as its inputs." ) formatted_data = [self.serialize_data_format(d, data_format) for d in data] for i, input_id in enumerate(input_ids): diff --git a/test/container_test.py b/test/container_test.py index 8a305a477..d3361b8df 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -1110,20 +1110,20 @@ def _batch_function_test_helper(batch_func, servicer, args_list, expected_output @skip_github_non_linux def test_batch_sync_function_full_batch(servicer): - inputs = [((10, 5), {}) for _ in range(4)] + inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [((10, 5), {}) for _ in range(4)] expected_outputs = [2] * 4 _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_outputs) @skip_github_non_linux def test_batch_sync_function_partial_batch(servicer): - inputs = [((10, 5), {}) for _ in range(2)] + inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [((10, 5), {}) for _ in range(2)] expected_outputs = [2] * 2 _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_outputs) def test_batch_sync_function_keyword_args(servicer): - inputs = [((10,), {"y": 5}) for _ in range(4)] + inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [((10,), {"y": 5}) for _ in range(4)] expected_outputs = [2] * 4 _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_outputs) @@ -1131,7 +1131,7 @@ def test_batch_sync_function_keyword_args(servicer): @skip_github_non_linux def test_batch_sync_function_inputs_outputs_error(servicer): # argument length does not match - inputs = [((10, 5), {}), ((10, 5, 1), {})] + inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [((10, 5), {}), ((10, 5, 1), {})] with pytest.raises(InvalidError) as err: _batch_function_test_helper("batch_function_sync", servicer, inputs, []) assert "Modal batch function batch_function_sync takes 2 positional arguments, but one call has 3." in str(err) @@ -1159,22 +1159,21 @@ def test_batch_sync_function_inputs_outputs_error(servicer): with pytest.raises(InvalidError) as err: _batch_function_test_helper("batch_function_outputs_wrong_len", servicer, inputs, []) assert ( - "Output of batch function batch_function_outputs_wrong_len must be \ - a list of the same length as its list of inputs." + "Output of batch function batch_function_outputs_wrong_len must be a list of the same length as its inputs." in str(err) ) @skip_github_non_linux def test_batch_sync_function_generic_error(servicer): - inputs = [((10, 0), {}) for _ in range(4)] + inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [((10, 0), {}) for _ in range(4)] expected_ouputs = ["ZeroDivisionError('division by zero')"] * 4 _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_ouputs, expected_status="failure") @skip_github_non_linux def test_batch_async_function(servicer): - inputs = [((10, 5), {}) for _ in range(4)] + inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [((10, 5), {}) for _ in range(4)] expected_outputs = [2] * 4 _batch_function_test_helper("batch_function_async", servicer, inputs, expected_outputs) From c71d9829f19a1ced9c0bf7d9f38881326af6cda9 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Wed, 31 Jul 2024 17:51:31 +0000 Subject: [PATCH 6/8] fix type check --- modal/_container_entrypoint.py | 2 +- modal/functions.py | 2 +- test/function_test.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index e47ea6b5f..f2a4a1e27 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -401,7 +401,7 @@ async def run_input_async( raise InvalidError(f"Async generator function returned value of type {type(res)}") if isinstance(input_ids, list) or isinstance(function_call_ids, list): raise InvalidError( - f"Batch function {finalized_function.callable.__name__} cannot return a generator." + f"Batch function {finalized_function.callable.__name__} cannot return generators." ) # Send up to this many outputs at a time. diff --git a/modal/functions.py b/modal/functions.py index 1d5dca34c..89423cee7 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -657,7 +657,7 @@ def from_args( else: raise InvalidError("Webhooks cannot be generators") if is_generator and batch_max_size: - raise InvalidError(f"Batch functions {info.raw_f.__name__} cannot return a generator") + raise InvalidError("Batch functions cannot return generators") if container_idle_timeout is not None and container_idle_timeout <= 0: raise InvalidError("`container_idle_timeout` must be > 0") diff --git a/test/function_test.py b/test/function_test.py index d73b351c5..59db9d675 100644 --- a/test/function_test.py +++ b/test/function_test.py @@ -819,7 +819,7 @@ def test_batch_function_invalid_error(): with pytest.raises(InvalidError, match="must be less than"): app.function(batch(batch_max_size=1, batch_linger_ms=10 * 60 * 1000))(dummy) - with pytest.raises(InvalidError, match="cannot return a generator"): + with pytest.raises(InvalidError, match="cannot return generators"): @app.function(serialized=True) @batch(batch_max_size=1, batch_linger_ms=1) From 15be06860c6832c8b603f13a110ab0478c1d37c2 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Wed, 31 Jul 2024 19:27:05 +0000 Subject: [PATCH 7/8] fix test on linux --- test/container_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/container_test.py b/test/container_test.py index d3361b8df..b43539edb 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -1122,6 +1122,7 @@ def test_batch_sync_function_partial_batch(servicer): _batch_function_test_helper("batch_function_sync", servicer, inputs, expected_outputs) +@skip_github_non_linux def test_batch_sync_function_keyword_args(servicer): inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [((10,), {"y": 5}) for _ in range(4)] expected_outputs = [2] * 4 From 55af97345c7991a307ea34db79cb6936ae9d95a9 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Wed, 31 Jul 2024 20:51:00 +0000 Subject: [PATCH 8/8] isolate input/output change --- modal/_container_entrypoint.py | 172 ++++++------------------------- modal/_container_io_manager.py | 180 +++++++-------------------------- modal/execution_context.py | 13 +-- 3 files changed, 67 insertions(+), 298 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index f2a4a1e27..b467e9185 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -13,7 +13,7 @@ import typing from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple from google.protobuf.message import Message from synchronicity import Interface @@ -323,92 +323,30 @@ def _sigint_handler(): self.loop.remove_signal_handler(signal.SIGINT) -def _aggregate_args_and_kwargs( - local_inputs: Union[LocalInput, List[LocalInput]], - callable: Callable[..., Any], -) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: - if isinstance(local_inputs, list): - param_names = list(inspect.signature(callable).parameters.keys()) - for param in inspect.signature(callable).parameters.values(): - if param.default is not inspect.Parameter.empty: - raise InvalidError(f"Modal batch function {callable.__name__} does not accept default arguments.") - - args_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(local_inputs))] - for i, local_input in enumerate(local_inputs): - params_len = len(local_input.args) + len(local_input.kwargs) - if params_len != len(param_names): - raise InvalidError( - f"Modal batch function {callable.__name__} takes {len(param_names)} positional arguments, but one call has {params_len}." # noqa - ) - for j, arg in enumerate(local_input.args): - args_by_inputs[i][param_names[j]] = arg - for k, v in local_input.kwargs.items(): - if k not in param_names: - raise InvalidError( - f"Modal batch function {callable.__name__} got an unexpected keyword argument {k} in one call." - ) - if k in args_by_inputs[i]: - raise InvalidError( - f"Modal batch function {callable.__name__} got multiple values for argument {k} in one call." - ) - args_by_inputs[i][k] = v - - formatted_kwargs = { - param_name: [args_by_inputs[i][param_name] for i in range(len(local_inputs))] for param_name in param_names - } - return tuple(), formatted_kwargs - - else: - return local_inputs.args, local_inputs.kwargs - - def call_function( user_code_event_loop: UserCodeEventLoop, container_io_manager: "modal._container_io_manager.ContainerIOManager", finalized_functions: Dict[str, FinalizedFunction], input_concurrency: int, - batch_max_size: Optional[int], - batch_linger_ms: Optional[int], ): - async def run_input_async( - finalized_function: FinalizedFunction, - local_inputs: Union[LocalInput, List[LocalInput]], - container_io_manager: "modal._container_io_manager.ContainerIOManager", - ) -> None: + async def run_input_async(finalized_function: FinalizedFunction, local_input: LocalInput) -> None: started_at = time.time() - input_ids: Union[str, List[str]] = ( - local_inputs.input_id - if isinstance(local_inputs, LocalInput) - else [local_input.input_id for local_input in local_inputs] - ) - function_call_ids: Union[str, List[str]] = ( - local_inputs.function_call_id - if isinstance(local_inputs, LocalInput) - else [local_input.function_call_id for local_input in local_inputs] - ) - args, kwargs = _aggregate_args_and_kwargs(local_inputs, finalized_function.callable) - reset_context = _set_current_context_ids(input_ids, function_call_ids) - async with container_io_manager.handle_input_exception.aio( - input_ids, started_at, finalized_function.callable.__name__ - ): - logger.debug(f"Starting input {input_ids} (async)") - res = finalized_function.callable(*args, **kwargs) - logger.debug(f"Finished input {input_ids} (async)") + reset_context = _set_current_context_ids(local_input.input_id, local_input.function_call_id) + async with container_io_manager.handle_input_exception.aio(local_input.input_id, started_at): + logger.debug(f"Starting input {local_input.input_id} (async)") + res = finalized_function.callable(*local_input.args, **local_input.kwargs) + logger.debug(f"Finished input {local_input.input_id} (async)") # TODO(erikbern): any exception below shouldn't be considered a user exception if finalized_function.is_generator: if not inspect.isasyncgen(res): raise InvalidError(f"Async generator function returned value of type {type(res)}") - if isinstance(input_ids, list) or isinstance(function_call_ids, list): - raise InvalidError( - f"Batch function {finalized_function.callable.__name__} cannot return generators." - ) # Send up to this many outputs at a time. generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024) generator_output_task = asyncio.create_task( container_io_manager.generator_output_task.aio( - function_call_ids, + local_input.function_call_id, finalized_function.data_format, generator_queue, ) @@ -423,11 +361,7 @@ async def run_input_async( await generator_output_task # Wait to finish sending generator outputs. message = api_pb2.GeneratorDone(items_total=item_count) await container_io_manager.push_output.aio( - input_ids, - started_at, - finalized_function.callable.__name__, - message, - api_pb2.DATA_FORMAT_GENERATOR_DONE, + local_input.input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE ) else: if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res): @@ -437,46 +371,27 @@ async def run_input_async( ) value = await res await container_io_manager.push_output.aio( - input_ids, started_at, finalized_function.callable.__name__, value, finalized_function.data_format + local_input.input_id, started_at, value, finalized_function.data_format ) reset_context() - def run_input_sync( - finalized_function: FinalizedFunction, - local_inputs: Union[LocalInput, List[LocalInput]], - container_io_manager: "modal._container_io_manager.ContainerIOManager", - ) -> None: + def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInput) -> None: started_at = time.time() - input_ids: Union[str, List[str]] = ( - local_inputs.input_id - if isinstance(local_inputs, LocalInput) - else [local_input.input_id for local_input in local_inputs] - ) - function_call_ids: Union[str, List[str]] = ( - local_inputs.function_call_id - if isinstance(local_inputs, LocalInput) - else [local_input.function_call_id for local_input in local_inputs] - ) - args, kwargs = _aggregate_args_and_kwargs(local_inputs, finalized_function.callable) - reset_context = _set_current_context_ids(input_ids, function_call_ids) - with container_io_manager.handle_input_exception(input_ids, started_at, finalized_function.callable.__name__): - logger.debug(f"Starting input {input_ids} (sync)") - res = finalized_function.callable(*args, **kwargs) - logger.debug(f"Finished input {input_ids} (sync)") + reset_context = _set_current_context_ids(local_input.input_id, local_input.function_call_id) + with container_io_manager.handle_input_exception(local_input.input_id, started_at): + logger.debug(f"Starting input {local_input.input_id} (sync)") + res = finalized_function.callable(*local_input.args, **local_input.kwargs) + logger.debug(f"Finished input {local_input.input_id} (sync)") # TODO(erikbern): any exception below shouldn't be considered a user exception if finalized_function.is_generator: if not inspect.isgenerator(res): raise InvalidError(f"Generator function returned value of type {type(res)}") - if isinstance(function_call_ids, list): - raise InvalidError( - f"Batch function {finalized_function.callable.__name__} cannot return generators." - ) # Send up to this many outputs at a time. generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024) generator_output_task: concurrent.futures.Future = container_io_manager.generator_output_task( # type: ignore - function_call_ids, + local_input.function_call_id, finalized_function.data_format, generator_queue, _future=True, # type: ignore # Synchronicity magic to return a future. @@ -491,11 +406,7 @@ def run_input_sync( generator_output_task.result() # Wait to finish sending generator outputs. message = api_pb2.GeneratorDone(items_total=item_count) container_io_manager.push_output( - input_ids, - started_at, - finalized_function.callable.__name__, - message, - api_pb2.DATA_FORMAT_GENERATOR_DONE, + local_input.input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE ) else: if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res): @@ -503,19 +414,9 @@ def run_input_sync( f"Sync (non-generator) function return value of type {type(res)}." " You might need to use @app.function(..., is_generator=True)." ) - container_io_manager.push_output( - input_ids, started_at, finalized_function.callable.__name__, res, finalized_function.data_format - ) + container_io_manager.push_output(local_input.input_id, started_at, res, finalized_function.data_format) reset_context() - def _get_finalized_functions(local_inputs: Union[LocalInput, List[LocalInput]]) -> FinalizedFunction: - if isinstance(local_inputs, list): - assert len(local_inputs) > 0 - assert len(set([local_input.method_name for local_input in local_inputs])) == 1 - return finalized_functions[local_inputs[0].method_name] - else: - return finalized_functions[local_inputs.method_name] - if input_concurrency > 1: with DaemonizedThreadPool(max_threads=input_concurrency) as thread_pool: @@ -524,28 +425,24 @@ async def run_concurrent_inputs(): # but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s # for them to resolve gracefully: async with TaskContext(0.01) as task_context: - async for local_inputs in container_io_manager.run_inputs_outputs.aio( - input_concurrency, batch_max_size, batch_linger_ms - ): - finalized_function = _get_finalized_functions(local_inputs) + async for local_input in container_io_manager.run_inputs_outputs.aio(input_concurrency): + finalized_function = finalized_functions[local_input.method_name] # Note that run_inputs_outputs will not return until the concurrency semaphore has # released all its slots so that they can be acquired by the run_inputs_outputs finalizer # This prevents leaving the task_context before outputs have been created # TODO: refactor to make this a bit more easy to follow? if finalized_function.is_async: - task_context.create_task( - run_input_async(finalized_function, local_inputs, container_io_manager) - ) + task_context.create_task(run_input_async(finalized_function, local_input)) else: # run sync input in thread - thread_pool.submit(run_input_sync, finalized_function, local_inputs, container_io_manager) + thread_pool.submit(run_input_sync, finalized_function, local_input) user_code_event_loop.run(run_concurrent_inputs()) else: - for local_inputs in container_io_manager.run_inputs_outputs(input_concurrency, batch_max_size, batch_linger_ms): - finalized_function = _get_finalized_functions(local_inputs) + for local_input in container_io_manager.run_inputs_outputs(input_concurrency): + finalized_function = finalized_functions[local_input.method_name] if finalized_function.is_async: - user_code_event_loop.run(run_input_async(finalized_function, local_inputs, container_io_manager)) + user_code_event_loop.run(run_input_async(finalized_function, local_input)) else: # Set up a custom signal handler for `SIGUSR1`, which gets translated to an InputCancellation # during function execution. This is sent to cancel inputs from the user @@ -556,7 +453,7 @@ def _cancel_input_signal_handler(signum, stackframe): # run this sync code in the main thread, blocking the "userland" event loop # this lets us cancel it using a signal handler that raises an exception try: - run_input_sync(finalized_function, local_inputs, container_io_manager) + run_input_sync(finalized_function, local_input) finally: signal.signal(signal.SIGUSR1, usr1_handler) # reset signal handler @@ -841,14 +738,10 @@ def main(container_args: api_pb2.ContainerArguments, client: Client): # Container can fetch multiple inputs simultaneously if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL: - # Concurrency and batching doesn't apply for `modal shell`. + # Concurrency doesn't apply for `modal shell`. input_concurrency = 1 - batch_max_size = 0 - batch_linger_ms = 0 else: input_concurrency = function_def.allow_concurrent_inputs or 1 - batch_max_size = function_def.batch_max_size or 0 - batch_linger_ms = function_def.batch_linger_ms or 0 # Get ids and metadata for objects (primarily functions and classes) on the app container_app: RunningApp = container_io_manager.get_app_objects() @@ -910,14 +803,7 @@ def breakpoint_wrapper(): # Execute the function. try: - call_function( - event_loop, - container_io_manager, - finalized_functions, - input_concurrency, - batch_max_size, - batch_linger_ms, - ) + call_function(event_loop, container_io_manager, finalized_functions, input_concurrency) finally: # Run exit handlers. From this point onward, ignore all SIGINT signals that come from # graceful shutdowns originating on the worker, as well as stray SIGUSR1 signals that diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 65f9c9d44..ded25a955 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -9,7 +9,7 @@ import traceback from dataclasses import dataclass from pathlib import Path -from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union +from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, List, Optional, Set, Tuple from google.protobuf.empty_pb2 import Empty from google.protobuf.message import Message @@ -47,8 +47,8 @@ class LocalInput: input_id: str function_call_id: str method_name: str - args: Tuple[Any, ...] - kwargs: Dict[str, Any] + args: Any + kwargs: Any class _ContainerIOManager: @@ -350,9 +350,7 @@ def get_max_inputs_to_fetch(self): return math.ceil(RTT_S / max(self.get_average_call_time(), 1e-6)) @synchronizer.no_io_translation - async def _generate_inputs( - self, - ) -> AsyncIterator[Union[Tuple[str, str, api_pb2.FunctionInput], List[Tuple[str, str, api_pb2.FunctionInput]]]]: + async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.FunctionInput]]: request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id) eof_received = False iteration = 0 @@ -360,8 +358,6 @@ async def _generate_inputs( request.average_call_time = self.get_average_call_time() request.max_values = self.get_max_inputs_to_fetch() # Deprecated; remove. request.input_concurrency = self._input_concurrency - request.batch_max_size = self._batch_max_size - request.batch_linger_ms = self._batch_linger_ms await self._semaphore.acquire() yielded = False @@ -379,11 +375,11 @@ async def _generate_inputs( ) await asyncio.sleep(response.rate_limit_sleep_duration) elif response.inputs: - if self._batch_max_size == 0: - # for input cancellations and concurrency logic we currently assume - # that there is no input buffering in the container - assert len(response.inputs) == 1 - item = response.inputs[0] + # for input cancellations and concurrency logic we currently assume + # that there is no input buffering in the container + assert len(response.inputs) == 1 + + for item in response.inputs: if item.kill_switch: logger.debug(f"Task {self.task_id} input kill signal input.") eof_received = True @@ -405,144 +401,48 @@ async def _generate_inputs( if item.input.final_input or self.function_def.max_inputs == 1: eof_received = True break - else: - assert len(response.inputs) <= request.batch_max_size - - inputs_list = [] - for item in response.inputs: - if item.kill_switch: - assert len(response.inputs) == 1 - logger.debug(f"Task {self.task_id} input kill signal input.") - eof_received = True - break - assert item.input_id not in self.cancelled_input_ids - - # If we got a pointer to a blob, download it from S3. - if item.input.WhichOneof("args_oneof") == "args_blob_id": - input_pb = await self.populate_input_blobs(item.input) - else: - input_pb = item.input - - inputs_list.append((item.input_id, item.function_call_id, input_pb)) - if item.input.final_input: - eof_received = True - logger.error("Final input not expected in batch input stream") - break - - if not eof_received: - yield inputs_list - yielded = True - - # We only support max_inputs = 1 at the moment - if self.function_def.max_inputs == 1: - eof_received = True - finally: if not yielded: self._semaphore.release() @synchronizer.no_io_translation - async def run_inputs_outputs( - self, - input_concurrency: int = 1, - batch_max_size: int = 0, - batch_linger_ms: int = 0, - ) -> AsyncIterator[Union[LocalInput, List[LocalInput]]]: + async def run_inputs_outputs(self, input_concurrency: int = 1) -> AsyncIterator[LocalInput]: # Ensure we do not fetch new inputs when container is too busy. # Before trying to fetch an input, acquire the semaphore: # - if no input is fetched, release the semaphore. # - or, when the output for the fetched input is sent, release the semaphore. self._input_concurrency = input_concurrency - self._batch_max_size = batch_max_size - self._batch_linger_ms = batch_linger_ms self._semaphore = asyncio.Semaphore(input_concurrency) - if batch_max_size == 0: - async for input_id, function_call_id, input_pb in self._generate_inputs(): - args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) - self.current_input_id, self.current_input_started_at = (input_id, time.time()) - yield LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs) - self.current_input_id, self.current_input_started_at = (None, None) - else: - async for inputs_list in self._generate_inputs(): - local_inputs_list = [] - for input_id, function_call_id, input_pb in inputs_list: - args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) - self.current_input_id, self.current_input_started_at = (input_id, time.time()) - local_inputs_list.append(LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs)) - yield local_inputs_list - self.current_input_id, self.current_input_started_at = (None, None) + async for input_id, function_call_id, input_pb in self._generate_inputs(): + args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) + self.current_input_id, self.current_input_started_at = (input_id, time.time()) + yield LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs) + self.current_input_id, self.current_input_started_at = (None, None) # collect all active input slots, meaning all inputs have wrapped up. for _ in range(input_concurrency): await self._semaphore.acquire() - async def _push_output( - self, - input_ids: Union[str, List[str]], - started_at: float, - function_name: str, - data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, - **kwargs, - ): - outputs = [] - if isinstance(input_ids, list): - formatted_data = None - if "data" in kwargs and kwargs["data"]: - # split the list of data in kwargs to respective input_ids - # report error for every input_id in batch call - if "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: - formatted_data = [kwargs.pop("data")] * len(input_ids) - else: - data = self.deserialize_data_format(kwargs.pop("data"), data_format) - if not isinstance(data, list): - raise InvalidError(f"Output of batch function {function_name} must be a list.") - if len(data) != len(input_ids): - raise InvalidError( - f"Output of batch function {function_name} must be a list of the same length as its inputs." - ) - formatted_data = [self.serialize_data_format(d, data_format) for d in data] - for i, input_id in enumerate(input_ids): - data = formatted_data[i] if formatted_data else None - result = ( - ( - # upload data to S3 if too big. - api_pb2.GenericResult(data_blob_id=await blob_upload(data, self._client.stub), **kwargs) - if len(data) > MAX_OBJECT_SIZE_BYTES - else api_pb2.GenericResult(data=data, **kwargs) - ) - if data - else api_pb2.GenericResult(**kwargs) - ) - outputs.append( - api_pb2.FunctionPutOutputsItem( - input_id=input_id, - input_started_at=started_at, - output_created_at=time.time(), - result=result, - data_format=data_format, - ) - ) - else: - # upload data to S3 if too big. - if "data" in kwargs and kwargs["data"] and len(kwargs["data"]) > MAX_OBJECT_SIZE_BYTES: - data_blob_id = await blob_upload(kwargs["data"], self._client.stub) - # mutating kwargs. - del kwargs["data"] - kwargs["data_blob_id"] = data_blob_id - - outputs.append( - api_pb2.FunctionPutOutputsItem( - input_id=input_ids, - input_started_at=started_at, - output_created_at=time.time(), - result=api_pb2.GenericResult(**kwargs), - data_format=data_format, - ) - ) + async def _push_output(self, input_id, started_at: float, data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, **kwargs): + # upload data to S3 if too big. + if "data" in kwargs and kwargs["data"] and len(kwargs["data"]) > MAX_OBJECT_SIZE_BYTES: + data_blob_id = await blob_upload(kwargs["data"], self._client.stub) + # mutating kwargs. + del kwargs["data"] + kwargs["data_blob_id"] = data_blob_id + + output = api_pb2.FunctionPutOutputsItem( + input_id=input_id, + input_started_at=started_at, + output_created_at=time.time(), + result=api_pb2.GenericResult(**kwargs), + data_format=data_format, + ) + await retry_transient_errors( self._client.stub.FunctionPutOutputs, - api_pb2.FunctionPutOutputsRequest(outputs=outputs), + api_pb2.FunctionPutOutputsRequest(outputs=[output]), additional_status_codes=[Status.RESOURCE_EXHAUSTED], max_retries=None, # Retry indefinitely, trying every 1s. ) @@ -602,9 +502,7 @@ async def handle_user_exception(self) -> AsyncGenerator[None, None]: raise UserException() @asynccontextmanager - async def handle_input_exception( - self, input_ids: Union[str, List[str]], started_at: float, function_name: str - ) -> AsyncGenerator[None, None]: + async def handle_input_exception(self, input_id, started_at: float) -> AsyncGenerator[None, None]: """Handle an exception while processing a function input.""" try: yield @@ -619,13 +517,9 @@ async def handle_input_exception( # just skip creating any output for this input and keep going with the next instead # it should have been marked as cancelled already in the backend at this point so it # won't be retried - logger.warning(f"The current input ({input_ids=}) was cancelled by a user request") + logger.warning(f"The current input ({input_id=}) was cancelled by a user request") await self.complete_call(started_at) return - except InvalidError as exc: - # If there is an error in batch function output, we need to explicitly raise it - if "Output of batch function" in exc.args[0]: - raise except BaseException as exc: # print exception so it's logged traceback.print_exc() @@ -647,9 +541,8 @@ async def handle_input_exception( repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception" await self._push_output( - input_ids, + input_id, started_at=started_at, - function_name=function_name, data_format=api_pb2.DATA_FORMAT_PICKLE, status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, data=self.serialize_exception(exc), @@ -666,11 +559,10 @@ async def complete_call(self, started_at): self._semaphore.release() @synchronizer.no_io_translation - async def push_output(self, input_id, started_at: float, function_name: str, data: Any, data_format: int) -> None: + async def push_output(self, input_id, started_at: float, data: Any, data_format: int) -> None: await self._push_output( input_id, started_at=started_at, - function_name=function_name, data_format=data_format, status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS, data=self.serialize_data_format(data, data_format), diff --git a/modal/execution_context.py b/modal/execution_context.py index bf1bd399d..764d85cf8 100644 --- a/modal/execution_context.py +++ b/modal/execution_context.py @@ -1,6 +1,6 @@ # Copyright Modal Labs 2024 from contextvars import ContextVar -from typing import Callable, List, Optional, Union +from typing import Callable, Optional from modal._container_io_manager import _ContainerIOManager from modal._utils.async_utils import synchronize_api @@ -70,16 +70,7 @@ def process_stuff(): return None -def _set_current_context_ids( - input_ids: Union[str, List[str]], function_call_ids: Union[str, List[str]] -) -> Callable[[], None]: - if isinstance(input_ids, list): - assert len(input_ids) == len(function_call_ids) and len(input_ids) > 0 - input_id = input_ids[0] - function_call_id = function_call_ids[0] - else: - input_id = input_ids - function_call_id = function_call_ids +def _set_current_context_ids(input_id: str, function_call_id: str) -> Callable[[], None]: input_token = _current_input_id.set(input_id) function_call_token = _current_function_call_id.set(function_call_id)