From 7d9896edab7900f8407e7978d176ac213f1928a9 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Wed, 31 Jul 2024 20:46:47 +0000 Subject: [PATCH 01/23] process batched inputs on client side --- modal/_container_entrypoint.py | 172 +++++++++++++++++++++++++------ modal/_container_io_manager.py | 180 ++++++++++++++++++++++++++------- modal/execution_context.py | 13 ++- 3 files changed, 298 insertions(+), 67 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index b467e9185..f2a4a1e27 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -13,7 +13,7 @@ import typing from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from google.protobuf.message import Message from synchronicity import Interface @@ -323,30 +323,92 @@ def _sigint_handler(): self.loop.remove_signal_handler(signal.SIGINT) +def _aggregate_args_and_kwargs( + local_inputs: Union[LocalInput, List[LocalInput]], + callable: Callable[..., Any], +) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + if isinstance(local_inputs, list): + param_names = list(inspect.signature(callable).parameters.keys()) + for param in inspect.signature(callable).parameters.values(): + if param.default is not inspect.Parameter.empty: + raise InvalidError(f"Modal batch function {callable.__name__} does not accept default arguments.") + + args_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(local_inputs))] + for i, local_input in enumerate(local_inputs): + params_len = len(local_input.args) + len(local_input.kwargs) + if params_len != len(param_names): + raise InvalidError( + f"Modal batch function {callable.__name__} takes {len(param_names)} positional arguments, but one call has {params_len}." # noqa + ) + for j, arg in enumerate(local_input.args): + args_by_inputs[i][param_names[j]] = arg + for k, v in local_input.kwargs.items(): + if k not in param_names: + raise InvalidError( + f"Modal batch function {callable.__name__} got an unexpected keyword argument {k} in one call." + ) + if k in args_by_inputs[i]: + raise InvalidError( + f"Modal batch function {callable.__name__} got multiple values for argument {k} in one call." + ) + args_by_inputs[i][k] = v + + formatted_kwargs = { + param_name: [args_by_inputs[i][param_name] for i in range(len(local_inputs))] for param_name in param_names + } + return tuple(), formatted_kwargs + + else: + return local_inputs.args, local_inputs.kwargs + + def call_function( user_code_event_loop: UserCodeEventLoop, container_io_manager: "modal._container_io_manager.ContainerIOManager", finalized_functions: Dict[str, FinalizedFunction], input_concurrency: int, + batch_max_size: Optional[int], + batch_linger_ms: Optional[int], ): - async def run_input_async(finalized_function: FinalizedFunction, local_input: LocalInput) -> None: + async def run_input_async( + finalized_function: FinalizedFunction, + local_inputs: Union[LocalInput, List[LocalInput]], + container_io_manager: "modal._container_io_manager.ContainerIOManager", + ) -> None: started_at = time.time() - reset_context = _set_current_context_ids(local_input.input_id, local_input.function_call_id) - async with container_io_manager.handle_input_exception.aio(local_input.input_id, started_at): - logger.debug(f"Starting input {local_input.input_id} (async)") - res = finalized_function.callable(*local_input.args, **local_input.kwargs) - logger.debug(f"Finished input {local_input.input_id} (async)") + input_ids: Union[str, List[str]] = ( + local_inputs.input_id + if isinstance(local_inputs, LocalInput) + else [local_input.input_id for local_input in local_inputs] + ) + function_call_ids: Union[str, List[str]] = ( + local_inputs.function_call_id + if isinstance(local_inputs, LocalInput) + else [local_input.function_call_id for local_input in local_inputs] + ) + args, kwargs = _aggregate_args_and_kwargs(local_inputs, finalized_function.callable) + reset_context = _set_current_context_ids(input_ids, function_call_ids) + async with container_io_manager.handle_input_exception.aio( + input_ids, started_at, finalized_function.callable.__name__ + ): + logger.debug(f"Starting input {input_ids} (async)") + res = finalized_function.callable(*args, **kwargs) + logger.debug(f"Finished input {input_ids} (async)") # TODO(erikbern): any exception below shouldn't be considered a user exception if finalized_function.is_generator: if not inspect.isasyncgen(res): raise InvalidError(f"Async generator function returned value of type {type(res)}") + if isinstance(input_ids, list) or isinstance(function_call_ids, list): + raise InvalidError( + f"Batch function {finalized_function.callable.__name__} cannot return generators." + ) # Send up to this many outputs at a time. generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024) generator_output_task = asyncio.create_task( container_io_manager.generator_output_task.aio( - local_input.function_call_id, + function_call_ids, finalized_function.data_format, generator_queue, ) @@ -361,7 +423,11 @@ async def run_input_async(finalized_function: FinalizedFunction, local_input: Lo await generator_output_task # Wait to finish sending generator outputs. message = api_pb2.GeneratorDone(items_total=item_count) await container_io_manager.push_output.aio( - local_input.input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE + input_ids, + started_at, + finalized_function.callable.__name__, + message, + api_pb2.DATA_FORMAT_GENERATOR_DONE, ) else: if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res): @@ -371,27 +437,46 @@ async def run_input_async(finalized_function: FinalizedFunction, local_input: Lo ) value = await res await container_io_manager.push_output.aio( - local_input.input_id, started_at, value, finalized_function.data_format + input_ids, started_at, finalized_function.callable.__name__, value, finalized_function.data_format ) reset_context() - def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInput) -> None: + def run_input_sync( + finalized_function: FinalizedFunction, + local_inputs: Union[LocalInput, List[LocalInput]], + container_io_manager: "modal._container_io_manager.ContainerIOManager", + ) -> None: started_at = time.time() - reset_context = _set_current_context_ids(local_input.input_id, local_input.function_call_id) - with container_io_manager.handle_input_exception(local_input.input_id, started_at): - logger.debug(f"Starting input {local_input.input_id} (sync)") - res = finalized_function.callable(*local_input.args, **local_input.kwargs) - logger.debug(f"Finished input {local_input.input_id} (sync)") + input_ids: Union[str, List[str]] = ( + local_inputs.input_id + if isinstance(local_inputs, LocalInput) + else [local_input.input_id for local_input in local_inputs] + ) + function_call_ids: Union[str, List[str]] = ( + local_inputs.function_call_id + if isinstance(local_inputs, LocalInput) + else [local_input.function_call_id for local_input in local_inputs] + ) + args, kwargs = _aggregate_args_and_kwargs(local_inputs, finalized_function.callable) + reset_context = _set_current_context_ids(input_ids, function_call_ids) + with container_io_manager.handle_input_exception(input_ids, started_at, finalized_function.callable.__name__): + logger.debug(f"Starting input {input_ids} (sync)") + res = finalized_function.callable(*args, **kwargs) + logger.debug(f"Finished input {input_ids} (sync)") # TODO(erikbern): any exception below shouldn't be considered a user exception if finalized_function.is_generator: if not inspect.isgenerator(res): raise InvalidError(f"Generator function returned value of type {type(res)}") + if isinstance(function_call_ids, list): + raise InvalidError( + f"Batch function {finalized_function.callable.__name__} cannot return generators." + ) # Send up to this many outputs at a time. generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024) generator_output_task: concurrent.futures.Future = container_io_manager.generator_output_task( # type: ignore - local_input.function_call_id, + function_call_ids, finalized_function.data_format, generator_queue, _future=True, # type: ignore # Synchronicity magic to return a future. @@ -406,7 +491,11 @@ def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInpu generator_output_task.result() # Wait to finish sending generator outputs. message = api_pb2.GeneratorDone(items_total=item_count) container_io_manager.push_output( - local_input.input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE + input_ids, + started_at, + finalized_function.callable.__name__, + message, + api_pb2.DATA_FORMAT_GENERATOR_DONE, ) else: if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res): @@ -414,9 +503,19 @@ def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInpu f"Sync (non-generator) function return value of type {type(res)}." " You might need to use @app.function(..., is_generator=True)." ) - container_io_manager.push_output(local_input.input_id, started_at, res, finalized_function.data_format) + container_io_manager.push_output( + input_ids, started_at, finalized_function.callable.__name__, res, finalized_function.data_format + ) reset_context() + def _get_finalized_functions(local_inputs: Union[LocalInput, List[LocalInput]]) -> FinalizedFunction: + if isinstance(local_inputs, list): + assert len(local_inputs) > 0 + assert len(set([local_input.method_name for local_input in local_inputs])) == 1 + return finalized_functions[local_inputs[0].method_name] + else: + return finalized_functions[local_inputs.method_name] + if input_concurrency > 1: with DaemonizedThreadPool(max_threads=input_concurrency) as thread_pool: @@ -425,24 +524,28 @@ async def run_concurrent_inputs(): # but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s # for them to resolve gracefully: async with TaskContext(0.01) as task_context: - async for local_input in container_io_manager.run_inputs_outputs.aio(input_concurrency): - finalized_function = finalized_functions[local_input.method_name] + async for local_inputs in container_io_manager.run_inputs_outputs.aio( + input_concurrency, batch_max_size, batch_linger_ms + ): + finalized_function = _get_finalized_functions(local_inputs) # Note that run_inputs_outputs will not return until the concurrency semaphore has # released all its slots so that they can be acquired by the run_inputs_outputs finalizer # This prevents leaving the task_context before outputs have been created # TODO: refactor to make this a bit more easy to follow? if finalized_function.is_async: - task_context.create_task(run_input_async(finalized_function, local_input)) + task_context.create_task( + run_input_async(finalized_function, local_inputs, container_io_manager) + ) else: # run sync input in thread - thread_pool.submit(run_input_sync, finalized_function, local_input) + thread_pool.submit(run_input_sync, finalized_function, local_inputs, container_io_manager) user_code_event_loop.run(run_concurrent_inputs()) else: - for local_input in container_io_manager.run_inputs_outputs(input_concurrency): - finalized_function = finalized_functions[local_input.method_name] + for local_inputs in container_io_manager.run_inputs_outputs(input_concurrency, batch_max_size, batch_linger_ms): + finalized_function = _get_finalized_functions(local_inputs) if finalized_function.is_async: - user_code_event_loop.run(run_input_async(finalized_function, local_input)) + user_code_event_loop.run(run_input_async(finalized_function, local_inputs, container_io_manager)) else: # Set up a custom signal handler for `SIGUSR1`, which gets translated to an InputCancellation # during function execution. This is sent to cancel inputs from the user @@ -453,7 +556,7 @@ def _cancel_input_signal_handler(signum, stackframe): # run this sync code in the main thread, blocking the "userland" event loop # this lets us cancel it using a signal handler that raises an exception try: - run_input_sync(finalized_function, local_input) + run_input_sync(finalized_function, local_inputs, container_io_manager) finally: signal.signal(signal.SIGUSR1, usr1_handler) # reset signal handler @@ -738,10 +841,14 @@ def main(container_args: api_pb2.ContainerArguments, client: Client): # Container can fetch multiple inputs simultaneously if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL: - # Concurrency doesn't apply for `modal shell`. + # Concurrency and batching doesn't apply for `modal shell`. input_concurrency = 1 + batch_max_size = 0 + batch_linger_ms = 0 else: input_concurrency = function_def.allow_concurrent_inputs or 1 + batch_max_size = function_def.batch_max_size or 0 + batch_linger_ms = function_def.batch_linger_ms or 0 # Get ids and metadata for objects (primarily functions and classes) on the app container_app: RunningApp = container_io_manager.get_app_objects() @@ -803,7 +910,14 @@ def breakpoint_wrapper(): # Execute the function. try: - call_function(event_loop, container_io_manager, finalized_functions, input_concurrency) + call_function( + event_loop, + container_io_manager, + finalized_functions, + input_concurrency, + batch_max_size, + batch_linger_ms, + ) finally: # Run exit handlers. From this point onward, ignore all SIGINT signals that come from # graceful shutdowns originating on the worker, as well as stray SIGUSR1 signals that diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index ded25a955..65f9c9d44 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -9,7 +9,7 @@ import traceback from dataclasses import dataclass from pathlib import Path -from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, List, Optional, Set, Tuple +from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union from google.protobuf.empty_pb2 import Empty from google.protobuf.message import Message @@ -47,8 +47,8 @@ class LocalInput: input_id: str function_call_id: str method_name: str - args: Any - kwargs: Any + args: Tuple[Any, ...] + kwargs: Dict[str, Any] class _ContainerIOManager: @@ -350,7 +350,9 @@ def get_max_inputs_to_fetch(self): return math.ceil(RTT_S / max(self.get_average_call_time(), 1e-6)) @synchronizer.no_io_translation - async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.FunctionInput]]: + async def _generate_inputs( + self, + ) -> AsyncIterator[Union[Tuple[str, str, api_pb2.FunctionInput], List[Tuple[str, str, api_pb2.FunctionInput]]]]: request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id) eof_received = False iteration = 0 @@ -358,6 +360,8 @@ async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.Functi request.average_call_time = self.get_average_call_time() request.max_values = self.get_max_inputs_to_fetch() # Deprecated; remove. request.input_concurrency = self._input_concurrency + request.batch_max_size = self._batch_max_size + request.batch_linger_ms = self._batch_linger_ms await self._semaphore.acquire() yielded = False @@ -375,11 +379,11 @@ async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.Functi ) await asyncio.sleep(response.rate_limit_sleep_duration) elif response.inputs: - # for input cancellations and concurrency logic we currently assume - # that there is no input buffering in the container - assert len(response.inputs) == 1 - - for item in response.inputs: + if self._batch_max_size == 0: + # for input cancellations and concurrency logic we currently assume + # that there is no input buffering in the container + assert len(response.inputs) == 1 + item = response.inputs[0] if item.kill_switch: logger.debug(f"Task {self.task_id} input kill signal input.") eof_received = True @@ -401,48 +405,144 @@ async def _generate_inputs(self) -> AsyncIterator[Tuple[str, str, api_pb2.Functi if item.input.final_input or self.function_def.max_inputs == 1: eof_received = True break + else: + assert len(response.inputs) <= request.batch_max_size + + inputs_list = [] + for item in response.inputs: + if item.kill_switch: + assert len(response.inputs) == 1 + logger.debug(f"Task {self.task_id} input kill signal input.") + eof_received = True + break + assert item.input_id not in self.cancelled_input_ids + + # If we got a pointer to a blob, download it from S3. + if item.input.WhichOneof("args_oneof") == "args_blob_id": + input_pb = await self.populate_input_blobs(item.input) + else: + input_pb = item.input + + inputs_list.append((item.input_id, item.function_call_id, input_pb)) + if item.input.final_input: + eof_received = True + logger.error("Final input not expected in batch input stream") + break + + if not eof_received: + yield inputs_list + yielded = True + + # We only support max_inputs = 1 at the moment + if self.function_def.max_inputs == 1: + eof_received = True + finally: if not yielded: self._semaphore.release() @synchronizer.no_io_translation - async def run_inputs_outputs(self, input_concurrency: int = 1) -> AsyncIterator[LocalInput]: + async def run_inputs_outputs( + self, + input_concurrency: int = 1, + batch_max_size: int = 0, + batch_linger_ms: int = 0, + ) -> AsyncIterator[Union[LocalInput, List[LocalInput]]]: # Ensure we do not fetch new inputs when container is too busy. # Before trying to fetch an input, acquire the semaphore: # - if no input is fetched, release the semaphore. # - or, when the output for the fetched input is sent, release the semaphore. self._input_concurrency = input_concurrency + self._batch_max_size = batch_max_size + self._batch_linger_ms = batch_linger_ms self._semaphore = asyncio.Semaphore(input_concurrency) - async for input_id, function_call_id, input_pb in self._generate_inputs(): - args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) - self.current_input_id, self.current_input_started_at = (input_id, time.time()) - yield LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs) - self.current_input_id, self.current_input_started_at = (None, None) + if batch_max_size == 0: + async for input_id, function_call_id, input_pb in self._generate_inputs(): + args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) + self.current_input_id, self.current_input_started_at = (input_id, time.time()) + yield LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs) + self.current_input_id, self.current_input_started_at = (None, None) + else: + async for inputs_list in self._generate_inputs(): + local_inputs_list = [] + for input_id, function_call_id, input_pb in inputs_list: + args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) + self.current_input_id, self.current_input_started_at = (input_id, time.time()) + local_inputs_list.append(LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs)) + yield local_inputs_list + self.current_input_id, self.current_input_started_at = (None, None) # collect all active input slots, meaning all inputs have wrapped up. for _ in range(input_concurrency): await self._semaphore.acquire() - async def _push_output(self, input_id, started_at: float, data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, **kwargs): - # upload data to S3 if too big. - if "data" in kwargs and kwargs["data"] and len(kwargs["data"]) > MAX_OBJECT_SIZE_BYTES: - data_blob_id = await blob_upload(kwargs["data"], self._client.stub) - # mutating kwargs. - del kwargs["data"] - kwargs["data_blob_id"] = data_blob_id - - output = api_pb2.FunctionPutOutputsItem( - input_id=input_id, - input_started_at=started_at, - output_created_at=time.time(), - result=api_pb2.GenericResult(**kwargs), - data_format=data_format, - ) - + async def _push_output( + self, + input_ids: Union[str, List[str]], + started_at: float, + function_name: str, + data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, + **kwargs, + ): + outputs = [] + if isinstance(input_ids, list): + formatted_data = None + if "data" in kwargs and kwargs["data"]: + # split the list of data in kwargs to respective input_ids + # report error for every input_id in batch call + if "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: + formatted_data = [kwargs.pop("data")] * len(input_ids) + else: + data = self.deserialize_data_format(kwargs.pop("data"), data_format) + if not isinstance(data, list): + raise InvalidError(f"Output of batch function {function_name} must be a list.") + if len(data) != len(input_ids): + raise InvalidError( + f"Output of batch function {function_name} must be a list of the same length as its inputs." + ) + formatted_data = [self.serialize_data_format(d, data_format) for d in data] + for i, input_id in enumerate(input_ids): + data = formatted_data[i] if formatted_data else None + result = ( + ( + # upload data to S3 if too big. + api_pb2.GenericResult(data_blob_id=await blob_upload(data, self._client.stub), **kwargs) + if len(data) > MAX_OBJECT_SIZE_BYTES + else api_pb2.GenericResult(data=data, **kwargs) + ) + if data + else api_pb2.GenericResult(**kwargs) + ) + outputs.append( + api_pb2.FunctionPutOutputsItem( + input_id=input_id, + input_started_at=started_at, + output_created_at=time.time(), + result=result, + data_format=data_format, + ) + ) + else: + # upload data to S3 if too big. + if "data" in kwargs and kwargs["data"] and len(kwargs["data"]) > MAX_OBJECT_SIZE_BYTES: + data_blob_id = await blob_upload(kwargs["data"], self._client.stub) + # mutating kwargs. + del kwargs["data"] + kwargs["data_blob_id"] = data_blob_id + + outputs.append( + api_pb2.FunctionPutOutputsItem( + input_id=input_ids, + input_started_at=started_at, + output_created_at=time.time(), + result=api_pb2.GenericResult(**kwargs), + data_format=data_format, + ) + ) await retry_transient_errors( self._client.stub.FunctionPutOutputs, - api_pb2.FunctionPutOutputsRequest(outputs=[output]), + api_pb2.FunctionPutOutputsRequest(outputs=outputs), additional_status_codes=[Status.RESOURCE_EXHAUSTED], max_retries=None, # Retry indefinitely, trying every 1s. ) @@ -502,7 +602,9 @@ async def handle_user_exception(self) -> AsyncGenerator[None, None]: raise UserException() @asynccontextmanager - async def handle_input_exception(self, input_id, started_at: float) -> AsyncGenerator[None, None]: + async def handle_input_exception( + self, input_ids: Union[str, List[str]], started_at: float, function_name: str + ) -> AsyncGenerator[None, None]: """Handle an exception while processing a function input.""" try: yield @@ -517,9 +619,13 @@ async def handle_input_exception(self, input_id, started_at: float) -> AsyncGene # just skip creating any output for this input and keep going with the next instead # it should have been marked as cancelled already in the backend at this point so it # won't be retried - logger.warning(f"The current input ({input_id=}) was cancelled by a user request") + logger.warning(f"The current input ({input_ids=}) was cancelled by a user request") await self.complete_call(started_at) return + except InvalidError as exc: + # If there is an error in batch function output, we need to explicitly raise it + if "Output of batch function" in exc.args[0]: + raise except BaseException as exc: # print exception so it's logged traceback.print_exc() @@ -541,8 +647,9 @@ async def handle_input_exception(self, input_id, started_at: float) -> AsyncGene repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception" await self._push_output( - input_id, + input_ids, started_at=started_at, + function_name=function_name, data_format=api_pb2.DATA_FORMAT_PICKLE, status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, data=self.serialize_exception(exc), @@ -559,10 +666,11 @@ async def complete_call(self, started_at): self._semaphore.release() @synchronizer.no_io_translation - async def push_output(self, input_id, started_at: float, data: Any, data_format: int) -> None: + async def push_output(self, input_id, started_at: float, function_name: str, data: Any, data_format: int) -> None: await self._push_output( input_id, started_at=started_at, + function_name=function_name, data_format=data_format, status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS, data=self.serialize_data_format(data, data_format), diff --git a/modal/execution_context.py b/modal/execution_context.py index 764d85cf8..bf1bd399d 100644 --- a/modal/execution_context.py +++ b/modal/execution_context.py @@ -1,6 +1,6 @@ # Copyright Modal Labs 2024 from contextvars import ContextVar -from typing import Callable, Optional +from typing import Callable, List, Optional, Union from modal._container_io_manager import _ContainerIOManager from modal._utils.async_utils import synchronize_api @@ -70,7 +70,16 @@ def process_stuff(): return None -def _set_current_context_ids(input_id: str, function_call_id: str) -> Callable[[], None]: +def _set_current_context_ids( + input_ids: Union[str, List[str]], function_call_ids: Union[str, List[str]] +) -> Callable[[], None]: + if isinstance(input_ids, list): + assert len(input_ids) == len(function_call_ids) and len(input_ids) > 0 + input_id = input_ids[0] + function_call_id = function_call_ids[0] + else: + input_id = input_ids + function_call_id = function_call_ids input_token = _current_input_id.set(input_id) function_call_token = _current_function_call_id.set(function_call_id) From c1cec07971a2cb3cb3f25f7f5da3c2f76f4236f2 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Thu, 1 Aug 2024 18:14:04 +0000 Subject: [PATCH 02/23] upload blob data without blocking --- modal/_container_io_manager.py | 126 ++++++++++++++++++--------------- 1 file changed, 67 insertions(+), 59 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 65f9c9d44..6a9c94ee0 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -7,7 +7,6 @@ import sys import time import traceback -from dataclasses import dataclass from pathlib import Path from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union @@ -42,13 +41,24 @@ class Sentinel: """Used to get type-stubs to work with this object.""" -@dataclass class LocalInput: - input_id: str - function_call_id: str - method_name: str - args: Tuple[Any, ...] - kwargs: Dict[str, Any] + def __init__( + self, + container_io_manager: "_ContainerIOManager", + input_id: str, + function_call_id: str, + input_pb: api_pb2.FunctionInput, + ): + self.input_id: str = input_id + self.function_call_id: str = function_call_id + self.method_name: str = input_pb.method_name + + args, kwargs = container_io_manager.deserialize(input_pb.args) if input_pb.args else ((), {}) + self.args: Tuple[Any, ...] = args + self.kwargs: Dict[str, Any] = kwargs + + container_io_manager.current_input_id = input_id + container_io_manager.current_input_started_at = time.time() class _ContainerIOManager: @@ -459,17 +469,13 @@ async def run_inputs_outputs( if batch_max_size == 0: async for input_id, function_call_id, input_pb in self._generate_inputs(): - args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) - self.current_input_id, self.current_input_started_at = (input_id, time.time()) - yield LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs) + yield LocalInput(self, input_id, function_call_id, input_pb) self.current_input_id, self.current_input_started_at = (None, None) else: async for inputs_list in self._generate_inputs(): local_inputs_list = [] for input_id, function_call_id, input_pb in inputs_list: - args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) - self.current_input_id, self.current_input_started_at = (input_id, time.time()) - local_inputs_list.append(LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs)) + local_inputs_list.append(LocalInput(self, input_id, function_call_id, input_pb)) yield local_inputs_list self.current_input_id, self.current_input_started_at = (None, None) @@ -477,6 +483,41 @@ async def run_inputs_outputs( for _ in range(input_concurrency): await self._semaphore.acquire() + async def _get_kwargs_process_blob_data(self, data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]: + if len(data) > MAX_OBJECT_SIZE_BYTES: + kwargs["data_blob_id"] = await blob_upload(data, self._client.stub) + else: + kwargs["data"] = data + return kwargs + + async def _get_kwargs( + self, kwargs: Dict[str, Any], input_ids: Union[str, List[str]], function_name: str, data_format: int + ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + if "data" not in kwargs or not kwargs["data"]: + return kwargs if isinstance(input_ids, str) else [kwargs] * len(input_ids) + # data is not batched, return a single kwargs. + if isinstance(input_ids, str): + return await self._get_kwargs_process_blob_data(kwargs.pop("data"), kwargs) + # data is batched, return a list of kwargs + # split the list of data in kwargs to respective input_ids and report error for every input_id in batch call. + if "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: + error_data = kwargs.pop("data") + return [await self._get_kwargs_process_blob_data(error_data, kwargs) for _ in input_ids] + else: + data = self.deserialize_data_format(kwargs.pop("data"), data_format) + if not isinstance(data, list): + raise InvalidError(f"Output of batch function {function_name} must be a list.") + if len(data) != len(input_ids): + raise InvalidError( + f"Output of batch function {function_name} must be a list of the same length as its inputs." + ) + return await asyncio.gather( + *[ + self._get_kwargs_process_blob_data(self.serialize_data_format(d, data_format), kwargs.copy()) + for d in data + ] + ) + async def _push_output( self, input_ids: Union[str, List[str]], @@ -485,53 +526,20 @@ async def _push_output( data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, **kwargs, ): - outputs = [] - if isinstance(input_ids, list): - formatted_data = None - if "data" in kwargs and kwargs["data"]: - # split the list of data in kwargs to respective input_ids - # report error for every input_id in batch call - if "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: - formatted_data = [kwargs.pop("data")] * len(input_ids) - else: - data = self.deserialize_data_format(kwargs.pop("data"), data_format) - if not isinstance(data, list): - raise InvalidError(f"Output of batch function {function_name} must be a list.") - if len(data) != len(input_ids): - raise InvalidError( - f"Output of batch function {function_name} must be a list of the same length as its inputs." - ) - formatted_data = [self.serialize_data_format(d, data_format) for d in data] - for i, input_id in enumerate(input_ids): - data = formatted_data[i] if formatted_data else None - result = ( - ( - # upload data to S3 if too big. - api_pb2.GenericResult(data_blob_id=await blob_upload(data, self._client.stub), **kwargs) - if len(data) > MAX_OBJECT_SIZE_BYTES - else api_pb2.GenericResult(data=data, **kwargs) - ) - if data - else api_pb2.GenericResult(**kwargs) - ) - outputs.append( - api_pb2.FunctionPutOutputsItem( - input_id=input_id, - input_started_at=started_at, - output_created_at=time.time(), - result=result, - data_format=data_format, - ) + kwargs = await self._get_kwargs(kwargs, input_ids, function_name, data_format) + if isinstance(input_ids, list) and isinstance(kwargs, list): + outputs = [ + api_pb2.FunctionPutOutputsItem( + input_id=input_id, + input_started_at=started_at, + output_created_at=time.time(), + result=api_pb2.GenericResult(**kwargs), + data_format=data_format, ) + for input_id, kwargs in zip(input_ids, kwargs) + ] else: - # upload data to S3 if too big. - if "data" in kwargs and kwargs["data"] and len(kwargs["data"]) > MAX_OBJECT_SIZE_BYTES: - data_blob_id = await blob_upload(kwargs["data"], self._client.stub) - # mutating kwargs. - del kwargs["data"] - kwargs["data_blob_id"] = data_blob_id - - outputs.append( + outputs = [ api_pb2.FunctionPutOutputsItem( input_id=input_ids, input_started_at=started_at, @@ -539,7 +547,7 @@ async def _push_output( result=api_pb2.GenericResult(**kwargs), data_format=data_format, ) - ) + ] await retry_transient_errors( self._client.stub.FunctionPutOutputs, api_pb2.FunctionPutOutputsRequest(outputs=outputs), From 651f838acdfa85f77daba890e8a1a07de69d8806 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Thu, 1 Aug 2024 18:21:25 +0000 Subject: [PATCH 03/23] fix type check --- modal/_container_io_manager.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 6a9c94ee0..8ba0cc57b 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -7,6 +7,7 @@ import sys import time import traceback +from dataclasses import dataclass from pathlib import Path from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union @@ -41,7 +42,14 @@ class Sentinel: """Used to get type-stubs to work with this object.""" +@dataclass class LocalInput: + input_id: str + function_call_id: str + method_name: str + args: Tuple[Any, ...] + kwargs: Dict[str, Any] + def __init__( self, container_io_manager: "_ContainerIOManager", @@ -49,13 +57,10 @@ def __init__( function_call_id: str, input_pb: api_pb2.FunctionInput, ): - self.input_id: str = input_id - self.function_call_id: str = function_call_id - self.method_name: str = input_pb.method_name - - args, kwargs = container_io_manager.deserialize(input_pb.args) if input_pb.args else ((), {}) - self.args: Tuple[Any, ...] = args - self.kwargs: Dict[str, Any] = kwargs + self.input_id = input_id + self.function_call_id = function_call_id + self.method_name = input_pb.method_name + self.args, self.kwargs = container_io_manager.deserialize(input_pb.args) if input_pb.args else ((), {}) container_io_manager.current_input_id = input_id container_io_manager.current_input_started_at = time.time() From b4f09d8c339e3420b39e1d0e4d9e4b652bf0c952 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Thu, 1 Aug 2024 18:30:56 +0000 Subject: [PATCH 04/23] cleanup --- modal/_container_entrypoint.py | 87 +++++++++++++++------------------- 1 file changed, 37 insertions(+), 50 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index f2a4a1e27..540030956 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -327,39 +327,42 @@ def _aggregate_args_and_kwargs( local_inputs: Union[LocalInput, List[LocalInput]], callable: Callable[..., Any], ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: - if isinstance(local_inputs, list): - param_names = list(inspect.signature(callable).parameters.keys()) - for param in inspect.signature(callable).parameters.values(): - if param.default is not inspect.Parameter.empty: - raise InvalidError(f"Modal batch function {callable.__name__} does not accept default arguments.") - - args_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(local_inputs))] - for i, local_input in enumerate(local_inputs): - params_len = len(local_input.args) + len(local_input.kwargs) - if params_len != len(param_names): + if isinstance(local_inputs, LocalInput): + return local_inputs.input_id, local_inputs.function_call_id, local_inputs.args, local_inputs.kwargs + + # aggregate args and kwargs for batched input + input_ids = [local_input.input_id for local_input in local_inputs] + function_call_ids = [local_input.function_call_id for local_input in local_inputs] + param_names = list(inspect.signature(callable).parameters.keys()) + for param in inspect.signature(callable).parameters.values(): + if param.default is not inspect.Parameter.empty: + raise InvalidError(f"Modal batch function {callable.__name__} does not accept default arguments.") + + args_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(local_inputs))] + for i, local_input in enumerate(local_inputs): + params_len = len(local_input.args) + len(local_input.kwargs) + if params_len != len(param_names): + raise InvalidError( + f"Modal batch function {callable.__name__} takes {len(param_names)} positional arguments, but one call has {params_len}." # noqa + ) + for j, arg in enumerate(local_input.args): + args_by_inputs[i][param_names[j]] = arg + for k, v in local_input.kwargs.items(): + if k not in param_names: raise InvalidError( - f"Modal batch function {callable.__name__} takes {len(param_names)} positional arguments, but one call has {params_len}." # noqa + f"Modal batch function {callable.__name__} got an unexpected keyword argument {k} in one call." ) - for j, arg in enumerate(local_input.args): - args_by_inputs[i][param_names[j]] = arg - for k, v in local_input.kwargs.items(): - if k not in param_names: - raise InvalidError( - f"Modal batch function {callable.__name__} got an unexpected keyword argument {k} in one call." - ) - if k in args_by_inputs[i]: - raise InvalidError( - f"Modal batch function {callable.__name__} got multiple values for argument {k} in one call." - ) - args_by_inputs[i][k] = v - - formatted_kwargs = { - param_name: [args_by_inputs[i][param_name] for i in range(len(local_inputs))] for param_name in param_names - } - return tuple(), formatted_kwargs + if k in args_by_inputs[i]: + raise InvalidError( + f"Modal batch function {callable.__name__} got multiple values for argument {k} in one call." + ) + args_by_inputs[i][k] = v - else: - return local_inputs.args, local_inputs.kwargs + # put all arg / args into a kwargs dict, with items being param name -> list of values of length len(local_inputs) + formatted_kwargs = { + param_name: [args_by_inputs[i][param_name] for i in range(len(local_inputs))] for param_name in param_names + } + return input_ids, function_call_ids, tuple(), formatted_kwargs def call_function( @@ -376,17 +379,9 @@ async def run_input_async( container_io_manager: "modal._container_io_manager.ContainerIOManager", ) -> None: started_at = time.time() - input_ids: Union[str, List[str]] = ( - local_inputs.input_id - if isinstance(local_inputs, LocalInput) - else [local_input.input_id for local_input in local_inputs] - ) - function_call_ids: Union[str, List[str]] = ( - local_inputs.function_call_id - if isinstance(local_inputs, LocalInput) - else [local_input.function_call_id for local_input in local_inputs] + input_ids, function_call_ids, args, kwargs = _aggregate_args_and_kwargs( + local_inputs, finalized_function.callable ) - args, kwargs = _aggregate_args_and_kwargs(local_inputs, finalized_function.callable) reset_context = _set_current_context_ids(input_ids, function_call_ids) async with container_io_manager.handle_input_exception.aio( input_ids, started_at, finalized_function.callable.__name__ @@ -447,17 +442,9 @@ def run_input_sync( container_io_manager: "modal._container_io_manager.ContainerIOManager", ) -> None: started_at = time.time() - input_ids: Union[str, List[str]] = ( - local_inputs.input_id - if isinstance(local_inputs, LocalInput) - else [local_input.input_id for local_input in local_inputs] - ) - function_call_ids: Union[str, List[str]] = ( - local_inputs.function_call_id - if isinstance(local_inputs, LocalInput) - else [local_input.function_call_id for local_input in local_inputs] + input_ids, function_call_ids, args, kwargs = _aggregate_args_and_kwargs( + local_inputs, finalized_function.callable ) - args, kwargs = _aggregate_args_and_kwargs(local_inputs, finalized_function.callable) reset_context = _set_current_context_ids(input_ids, function_call_ids) with container_io_manager.handle_input_exception(input_ids, started_at, finalized_function.callable.__name__): logger.debug(f"Starting input {input_ids} (sync)") From 853c37b29fc26dc3a6cceb104ebec5227ce192d6 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Thu, 1 Aug 2024 18:40:12 +0000 Subject: [PATCH 05/23] fix type check --- modal/_container_entrypoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 540030956..ef76c01a5 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -326,7 +326,7 @@ def _sigint_handler(): def _aggregate_args_and_kwargs( local_inputs: Union[LocalInput, List[LocalInput]], callable: Callable[..., Any], -) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: +) -> Tuple[Union[str, List[str]], Union[str, List[str]], Tuple[Any, ...], Dict[str, Any]]: if isinstance(local_inputs, LocalInput): return local_inputs.input_id, local_inputs.function_call_id, local_inputs.args, local_inputs.kwargs From d2a0f6bbbfd616a91ce7ab353190157ae34571ab Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Thu, 1 Aug 2024 23:19:24 +0000 Subject: [PATCH 06/23] reduce duplicate code --- modal/_container_entrypoint.py | 71 +++++++++++------ modal/_container_io_manager.py | 139 +++++++++++++-------------------- 2 files changed, 100 insertions(+), 110 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index ef76c01a5..c6bce9cc7 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -13,7 +13,7 @@ import typing from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple from google.protobuf.message import Message from synchronicity import Interface @@ -324,15 +324,18 @@ def _sigint_handler(): def _aggregate_args_and_kwargs( - local_inputs: Union[LocalInput, List[LocalInput]], + local_inputs: List[LocalInput], callable: Callable[..., Any], -) -> Tuple[Union[str, List[str]], Union[str, List[str]], Tuple[Any, ...], Dict[str, Any]]: - if isinstance(local_inputs, LocalInput): - return local_inputs.input_id, local_inputs.function_call_id, local_inputs.args, local_inputs.kwargs - - # aggregate args and kwargs for batched input + is_batched: bool, +) -> Tuple[List[str], List[str], Tuple[Any, ...], Dict[str, Any]]: input_ids = [local_input.input_id for local_input in local_inputs] function_call_ids = [local_input.function_call_id for local_input in local_inputs] + + if not is_batched: + assert len(local_inputs) == 1 + return input_ids, function_call_ids, local_inputs[0].args, local_inputs[0].kwargs + + # aggregate args and kwargs for batched input param_names = list(inspect.signature(callable).parameters.keys()) for param in inspect.signature(callable).parameters.values(): if param.default is not inspect.Parameter.empty: @@ -375,16 +378,17 @@ def call_function( ): async def run_input_async( finalized_function: FinalizedFunction, - local_inputs: Union[LocalInput, List[LocalInput]], + local_inputs: List[LocalInput], container_io_manager: "modal._container_io_manager.ContainerIOManager", ) -> None: started_at = time.time() + is_batched = container_io_manager.is_batched() input_ids, function_call_ids, args, kwargs = _aggregate_args_and_kwargs( - local_inputs, finalized_function.callable + local_inputs, finalized_function.callable, is_batched ) reset_context = _set_current_context_ids(input_ids, function_call_ids) async with container_io_manager.handle_input_exception.aio( - input_ids, started_at, finalized_function.callable.__name__ + input_ids, started_at, finalized_function.callable.__name__, is_batched ): logger.debug(f"Starting input {input_ids} (async)") res = finalized_function.callable(*args, **kwargs) @@ -394,16 +398,17 @@ async def run_input_async( if finalized_function.is_generator: if not inspect.isasyncgen(res): raise InvalidError(f"Async generator function returned value of type {type(res)}") - if isinstance(input_ids, list) or isinstance(function_call_ids, list): + if is_batched: raise InvalidError( f"Batch function {finalized_function.callable.__name__} cannot return generators." ) + assert len(function_call_ids) == 1 # Send up to this many outputs at a time. generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024) generator_output_task = asyncio.create_task( container_io_manager.generator_output_task.aio( - function_call_ids, + function_call_ids[0], finalized_function.data_format, generator_queue, ) @@ -419,6 +424,7 @@ async def run_input_async( message = api_pb2.GeneratorDone(items_total=item_count) await container_io_manager.push_output.aio( input_ids, + is_batched, started_at, finalized_function.callable.__name__, message, @@ -432,21 +438,29 @@ async def run_input_async( ) value = await res await container_io_manager.push_output.aio( - input_ids, started_at, finalized_function.callable.__name__, value, finalized_function.data_format + input_ids, + is_batched, + started_at, + finalized_function.callable.__name__, + value, + finalized_function.data_format, ) reset_context() def run_input_sync( finalized_function: FinalizedFunction, - local_inputs: Union[LocalInput, List[LocalInput]], + local_inputs: List[LocalInput], container_io_manager: "modal._container_io_manager.ContainerIOManager", ) -> None: started_at = time.time() + is_batched = container_io_manager.is_batched() input_ids, function_call_ids, args, kwargs = _aggregate_args_and_kwargs( - local_inputs, finalized_function.callable + local_inputs, finalized_function.callable, is_batched ) reset_context = _set_current_context_ids(input_ids, function_call_ids) - with container_io_manager.handle_input_exception(input_ids, started_at, finalized_function.callable.__name__): + with container_io_manager.handle_input_exception( + input_ids, started_at, finalized_function.callable.__name__, is_batched + ): logger.debug(f"Starting input {input_ids} (sync)") res = finalized_function.callable(*args, **kwargs) logger.debug(f"Finished input {input_ids} (sync)") @@ -455,15 +469,16 @@ def run_input_sync( if finalized_function.is_generator: if not inspect.isgenerator(res): raise InvalidError(f"Generator function returned value of type {type(res)}") - if isinstance(function_call_ids, list): + if is_batched: raise InvalidError( f"Batch function {finalized_function.callable.__name__} cannot return generators." ) + assert len(function_call_ids) == 1 # Send up to this many outputs at a time. generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024) generator_output_task: concurrent.futures.Future = container_io_manager.generator_output_task( # type: ignore - function_call_ids, + function_call_ids[0], finalized_function.data_format, generator_queue, _future=True, # type: ignore # Synchronicity magic to return a future. @@ -479,6 +494,7 @@ def run_input_sync( message = api_pb2.GeneratorDone(items_total=item_count) container_io_manager.push_output( input_ids, + is_batched, started_at, finalized_function.callable.__name__, message, @@ -491,17 +507,20 @@ def run_input_sync( " You might need to use @app.function(..., is_generator=True)." ) container_io_manager.push_output( - input_ids, started_at, finalized_function.callable.__name__, res, finalized_function.data_format + input_ids, + is_batched, + started_at, + finalized_function.callable.__name__, + res, + finalized_function.data_format, ) reset_context() - def _get_finalized_functions(local_inputs: Union[LocalInput, List[LocalInput]]) -> FinalizedFunction: - if isinstance(local_inputs, list): - assert len(local_inputs) > 0 - assert len(set([local_input.method_name for local_input in local_inputs])) == 1 - return finalized_functions[local_inputs[0].method_name] - else: - return finalized_functions[local_inputs.method_name] + def _get_finalized_functions(local_inputs: List[LocalInput]) -> FinalizedFunction: + assert len(local_inputs) > 0 + # all functions in a batch must have the same method name + assert len(set([local_input.method_name for local_input in local_inputs])) == 1 + return finalized_functions[local_inputs[0].method_name] if input_concurrency > 1: with DaemonizedThreadPool(max_threads=input_concurrency) as thread_pool: diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 8ba0cc57b..626ab5432 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -9,7 +9,7 @@ import traceback from dataclasses import dataclass from pathlib import Path -from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union +from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, Dict, List, Optional, Set, Tuple from google.protobuf.empty_pb2 import Empty from google.protobuf.message import Message @@ -114,6 +114,8 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client): self.current_input_started_at = None self._input_concurrency = None + self._batch_max_size = None + self._batch_linger_ms = None self._semaphore = None self._environment_name = container_args.environment_name @@ -137,6 +139,9 @@ def _reset_singleton(cls): """Only used for tests.""" cls._singleton = None + def is_batched(self) -> bool: + return self.function_def.batch_max_size > 0 + async def _run_heartbeat_loop(self): while 1: t0 = time.monotonic() @@ -367,7 +372,7 @@ def get_max_inputs_to_fetch(self): @synchronizer.no_io_translation async def _generate_inputs( self, - ) -> AsyncIterator[Union[Tuple[str, str, api_pb2.FunctionInput], List[Tuple[str, str, api_pb2.FunctionInput]]]]: + ) -> AsyncIterator[List[Tuple[str, str, api_pb2.FunctionInput]]]: request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id) eof_received = False iteration = 0 @@ -394,16 +399,18 @@ async def _generate_inputs( ) await asyncio.sleep(response.rate_limit_sleep_duration) elif response.inputs: - if self._batch_max_size == 0: - # for input cancellations and concurrency logic we currently assume - # that there is no input buffering in the container - assert len(response.inputs) == 1 - item = response.inputs[0] + # for input cancellations and concurrency logic we currently assume + # that there is no input buffering in the container + assert 0 < len(response.inputs) <= max(1, request.batch_max_size) + inputs_list = [] + for item in response.inputs: if item.kill_switch: + assert len(response.inputs) == 1 logger.debug(f"Task {self.task_id} input kill signal input.") eof_received = True break if item.input_id in self.cancelled_input_ids: + assert request.batch_max_size == 0 continue # If we got a pointer to a blob, download it from S3. @@ -412,45 +419,20 @@ async def _generate_inputs( else: input_pb = item.input - # If yielded, allow semaphore to be released via complete_call - yield (item.input_id, item.function_call_id, input_pb) - yielded = True + inputs_list.append((item.input_id, item.function_call_id, input_pb)) - # We only support max_inputs = 1 at the moment - if item.input.final_input or self.function_def.max_inputs == 1: + if item.input.final_input: eof_received = True - break - else: - assert len(response.inputs) <= request.batch_max_size - - inputs_list = [] - for item in response.inputs: - if item.kill_switch: - assert len(response.inputs) == 1 - logger.debug(f"Task {self.task_id} input kill signal input.") - eof_received = True - break - assert item.input_id not in self.cancelled_input_ids - - # If we got a pointer to a blob, download it from S3. - if item.input.WhichOneof("args_oneof") == "args_blob_id": - input_pb = await self.populate_input_blobs(item.input) - else: - input_pb = item.input - - inputs_list.append((item.input_id, item.function_call_id, input_pb)) - if item.input.final_input: - eof_received = True + if request.batch_max_size != 0: logger.error("Final input not expected in batch input stream") - break - - if not eof_received: - yield inputs_list - yielded = True + break - # We only support max_inputs = 1 at the moment - if self.function_def.max_inputs == 1: - eof_received = True + if not eof_received: + yield inputs_list + yielded = True + # We only support max_inputs = 1 at the moment + if self.function_def.max_inputs == 1: + eof_received = True finally: if not yielded: @@ -462,7 +444,7 @@ async def run_inputs_outputs( input_concurrency: int = 1, batch_max_size: int = 0, batch_linger_ms: int = 0, - ) -> AsyncIterator[Union[LocalInput, List[LocalInput]]]: + ) -> AsyncIterator[List[LocalInput]]: # Ensure we do not fetch new inputs when container is too busy. # Before trying to fetch an input, acquire the semaphore: # - if no input is fetched, release the semaphore. @@ -472,17 +454,12 @@ async def run_inputs_outputs( self._batch_linger_ms = batch_linger_ms self._semaphore = asyncio.Semaphore(input_concurrency) - if batch_max_size == 0: - async for input_id, function_call_id, input_pb in self._generate_inputs(): - yield LocalInput(self, input_id, function_call_id, input_pb) - self.current_input_id, self.current_input_started_at = (None, None) - else: - async for inputs_list in self._generate_inputs(): - local_inputs_list = [] - for input_id, function_call_id, input_pb in inputs_list: - local_inputs_list.append(LocalInput(self, input_id, function_call_id, input_pb)) - yield local_inputs_list - self.current_input_id, self.current_input_started_at = (None, None) + async for inputs_list in self._generate_inputs(): + local_inputs_list = [] + for input_id, function_call_id, input_pb in inputs_list: + local_inputs_list.append(LocalInput(self, input_id, function_call_id, input_pb)) + yield local_inputs_list + self.current_input_id, self.current_input_started_at = (None, None) # collect all active input slots, meaning all inputs have wrapped up. for _ in range(input_concurrency): @@ -496,13 +473,13 @@ async def _get_kwargs_process_blob_data(self, data: bytes, kwargs: Dict[str, Any return kwargs async def _get_kwargs( - self, kwargs: Dict[str, Any], input_ids: Union[str, List[str]], function_name: str, data_format: int - ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + self, kwargs: Dict[str, Any], input_ids: List[str], function_name: str, data_format: int, is_batched: bool + ) -> List[Dict[str, Any]]: if "data" not in kwargs or not kwargs["data"]: - return kwargs if isinstance(input_ids, str) else [kwargs] * len(input_ids) + return [kwargs] * len(input_ids) # data is not batched, return a single kwargs. - if isinstance(input_ids, str): - return await self._get_kwargs_process_blob_data(kwargs.pop("data"), kwargs) + if not is_batched: + return [await self._get_kwargs_process_blob_data(kwargs.pop("data"), kwargs)] # data is batched, return a list of kwargs # split the list of data in kwargs to respective input_ids and report error for every input_id in batch call. if "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: @@ -525,34 +502,24 @@ async def _get_kwargs( async def _push_output( self, - input_ids: Union[str, List[str]], + input_ids: List[str], + is_batched: bool, started_at: float, function_name: str, data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, **kwargs, ): - kwargs = await self._get_kwargs(kwargs, input_ids, function_name, data_format) - if isinstance(input_ids, list) and isinstance(kwargs, list): - outputs = [ - api_pb2.FunctionPutOutputsItem( - input_id=input_id, - input_started_at=started_at, - output_created_at=time.time(), - result=api_pb2.GenericResult(**kwargs), - data_format=data_format, - ) - for input_id, kwargs in zip(input_ids, kwargs) - ] - else: - outputs = [ - api_pb2.FunctionPutOutputsItem( - input_id=input_ids, - input_started_at=started_at, - output_created_at=time.time(), - result=api_pb2.GenericResult(**kwargs), - data_format=data_format, - ) - ] + kwargs = await self._get_kwargs(kwargs, input_ids, function_name, data_format, is_batched) + outputs = [ + api_pb2.FunctionPutOutputsItem( + input_id=input_ids, + input_started_at=started_at, + output_created_at=time.time(), + result=api_pb2.GenericResult(**kwargs), + data_format=data_format, + ) + for input_ids, kwargs in zip(input_ids, kwargs) + ] await retry_transient_errors( self._client.stub.FunctionPutOutputs, api_pb2.FunctionPutOutputsRequest(outputs=outputs), @@ -616,7 +583,7 @@ async def handle_user_exception(self) -> AsyncGenerator[None, None]: @asynccontextmanager async def handle_input_exception( - self, input_ids: Union[str, List[str]], started_at: float, function_name: str + self, input_ids: List[str], started_at: float, function_name: str, is_batched: bool ) -> AsyncGenerator[None, None]: """Handle an exception while processing a function input.""" try: @@ -661,6 +628,7 @@ async def handle_input_exception( await self._push_output( input_ids, + is_batched=is_batched, started_at=started_at, function_name=function_name, data_format=api_pb2.DATA_FORMAT_PICKLE, @@ -679,9 +647,12 @@ async def complete_call(self, started_at): self._semaphore.release() @synchronizer.no_io_translation - async def push_output(self, input_id, started_at: float, function_name: str, data: Any, data_format: int) -> None: + async def push_output( + self, input_id, is_batched: bool, started_at: float, function_name: str, data: Any, data_format: int + ) -> None: await self._push_output( input_id, + is_batched=is_batched, started_at=started_at, function_name=function_name, data_format=data_format, From 14896853fb159422bedf5ac0af9c825c172020eb Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Thu, 1 Aug 2024 23:22:06 +0000 Subject: [PATCH 07/23] local input change --- modal/_container_io_manager.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 626ab5432..d3c8e34f9 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -47,23 +47,8 @@ class LocalInput: input_id: str function_call_id: str method_name: str - args: Tuple[Any, ...] - kwargs: Dict[str, Any] - - def __init__( - self, - container_io_manager: "_ContainerIOManager", - input_id: str, - function_call_id: str, - input_pb: api_pb2.FunctionInput, - ): - self.input_id = input_id - self.function_call_id = function_call_id - self.method_name = input_pb.method_name - self.args, self.kwargs = container_io_manager.deserialize(input_pb.args) if input_pb.args else ((), {}) - - container_io_manager.current_input_id = input_id - container_io_manager.current_input_started_at = time.time() + args: Any + kwargs: Any class _ContainerIOManager: @@ -457,7 +442,9 @@ async def run_inputs_outputs( async for inputs_list in self._generate_inputs(): local_inputs_list = [] for input_id, function_call_id, input_pb in inputs_list: - local_inputs_list.append(LocalInput(self, input_id, function_call_id, input_pb)) + args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) + self.current_input_id, self.current_input_started_at = (input_id, time.time()) + local_inputs_list.append(LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs)) yield local_inputs_list self.current_input_id, self.current_input_started_at = (None, None) From 3b8c11fe403847d3ed71be5633cbed46fdb21536 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Sun, 4 Aug 2024 23:35:49 +0000 Subject: [PATCH 08/23] refactor --- modal/_container_entrypoint.py | 171 +++------------ modal/_container_io_manager.py | 370 +++++++++++++++++++-------------- modal/execution_context.py | 16 +- test/test_asgi_wrapper.py | 16 +- 4 files changed, 255 insertions(+), 318 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index c6bce9cc7..c390418cc 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -18,6 +18,7 @@ from google.protobuf.message import Message from synchronicity import Interface +from modal._container_io_manager import FinalizedFunction, IOContext from modal_proto import api_pb2 from ._asgi import ( @@ -28,7 +29,7 @@ webhook_asgi_app, wsgi_app_wrapper, ) -from ._container_io_manager import ContainerIOManager, LocalInput, UserException, _ContainerIOManager +from ._container_io_manager import ContainerIOManager, UserException, _ContainerIOManager from ._proxy_tunnel import proxy_tunnel from ._serialization import deserialize, deserialize_proto_params from ._utils.async_utils import TaskContext, synchronizer @@ -101,14 +102,6 @@ def construct_webhook_callable( raise InvalidError(f"Unrecognized web endpoint type {webhook_config.type}") -@dataclass -class FinalizedFunction: - callable: Callable[..., Any] - is_async: bool - is_generator: bool - data_format: int # api_pb2.DataFormat - - class Service(metaclass=ABCMeta): """Common interface for singular functions and class-based "services" @@ -323,51 +316,6 @@ def _sigint_handler(): self.loop.remove_signal_handler(signal.SIGINT) -def _aggregate_args_and_kwargs( - local_inputs: List[LocalInput], - callable: Callable[..., Any], - is_batched: bool, -) -> Tuple[List[str], List[str], Tuple[Any, ...], Dict[str, Any]]: - input_ids = [local_input.input_id for local_input in local_inputs] - function_call_ids = [local_input.function_call_id for local_input in local_inputs] - - if not is_batched: - assert len(local_inputs) == 1 - return input_ids, function_call_ids, local_inputs[0].args, local_inputs[0].kwargs - - # aggregate args and kwargs for batched input - param_names = list(inspect.signature(callable).parameters.keys()) - for param in inspect.signature(callable).parameters.values(): - if param.default is not inspect.Parameter.empty: - raise InvalidError(f"Modal batch function {callable.__name__} does not accept default arguments.") - - args_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(local_inputs))] - for i, local_input in enumerate(local_inputs): - params_len = len(local_input.args) + len(local_input.kwargs) - if params_len != len(param_names): - raise InvalidError( - f"Modal batch function {callable.__name__} takes {len(param_names)} positional arguments, but one call has {params_len}." # noqa - ) - for j, arg in enumerate(local_input.args): - args_by_inputs[i][param_names[j]] = arg - for k, v in local_input.kwargs.items(): - if k not in param_names: - raise InvalidError( - f"Modal batch function {callable.__name__} got an unexpected keyword argument {k} in one call." - ) - if k in args_by_inputs[i]: - raise InvalidError( - f"Modal batch function {callable.__name__} got multiple values for argument {k} in one call." - ) - args_by_inputs[i][k] = v - - # put all arg / args into a kwargs dict, with items being param name -> list of values of length len(local_inputs) - formatted_kwargs = { - param_name: [args_by_inputs[i][param_name] for i in range(len(local_inputs))] for param_name in param_names - } - return input_ids, function_call_ids, tuple(), formatted_kwargs - - def call_function( user_code_event_loop: UserCodeEventLoop, container_io_manager: "modal._container_io_manager.ContainerIOManager", @@ -376,40 +324,23 @@ def call_function( batch_max_size: Optional[int], batch_linger_ms: Optional[int], ): - async def run_input_async( - finalized_function: FinalizedFunction, - local_inputs: List[LocalInput], - container_io_manager: "modal._container_io_manager.ContainerIOManager", - ) -> None: + async def run_input_async(io_context: IOContext) -> None: started_at = time.time() - is_batched = container_io_manager.is_batched() - input_ids, function_call_ids, args, kwargs = _aggregate_args_and_kwargs( - local_inputs, finalized_function.callable, is_batched - ) + input_ids, function_call_ids = io_context.input_ids, io_context.function_call_ids reset_context = _set_current_context_ids(input_ids, function_call_ids) - async with container_io_manager.handle_input_exception.aio( - input_ids, started_at, finalized_function.callable.__name__, is_batched - ): - logger.debug(f"Starting input {input_ids} (async)") - res = finalized_function.callable(*args, **kwargs) - logger.debug(f"Finished input {input_ids} (async)") - + async with container_io_manager.handle_input_exception.aio(io_context, started_at): + res = io_context.call_finalized_function() # TODO(erikbern): any exception below shouldn't be considered a user exception - if finalized_function.is_generator: + if io_context.finalized_function.is_generator: if not inspect.isasyncgen(res): raise InvalidError(f"Async generator function returned value of type {type(res)}") - if is_batched: - raise InvalidError( - f"Batch function {finalized_function.callable.__name__} cannot return generators." - ) - assert len(function_call_ids) == 1 # Send up to this many outputs at a time. generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024) generator_output_task = asyncio.create_task( container_io_manager.generator_output_task.aio( function_call_ids[0], - finalized_function.data_format, + io_context.finalized_function.data_format, generator_queue, ) ) @@ -423,10 +354,8 @@ async def run_input_async( await generator_output_task # Wait to finish sending generator outputs. message = api_pb2.GeneratorDone(items_total=item_count) await container_io_manager.push_output.aio( - input_ids, - is_batched, + io_context, started_at, - finalized_function.callable.__name__, message, api_pb2.DATA_FORMAT_GENERATOR_DONE, ) @@ -438,48 +367,30 @@ async def run_input_async( ) value = await res await container_io_manager.push_output.aio( - input_ids, - is_batched, + io_context, started_at, - finalized_function.callable.__name__, value, - finalized_function.data_format, + io_context.finalized_function.data_format, ) reset_context() - def run_input_sync( - finalized_function: FinalizedFunction, - local_inputs: List[LocalInput], - container_io_manager: "modal._container_io_manager.ContainerIOManager", - ) -> None: + def run_input_sync(io_context: IOContext) -> None: started_at = time.time() - is_batched = container_io_manager.is_batched() - input_ids, function_call_ids, args, kwargs = _aggregate_args_and_kwargs( - local_inputs, finalized_function.callable, is_batched - ) + input_ids, function_call_ids = io_context.input_ids, io_context.function_call_ids reset_context = _set_current_context_ids(input_ids, function_call_ids) - with container_io_manager.handle_input_exception( - input_ids, started_at, finalized_function.callable.__name__, is_batched - ): - logger.debug(f"Starting input {input_ids} (sync)") - res = finalized_function.callable(*args, **kwargs) - logger.debug(f"Finished input {input_ids} (sync)") + with container_io_manager.handle_input_exception(io_context, started_at): + res = io_context.call_finalized_function() # TODO(erikbern): any exception below shouldn't be considered a user exception - if finalized_function.is_generator: + if io_context.finalized_function.is_generator: if not inspect.isgenerator(res): raise InvalidError(f"Generator function returned value of type {type(res)}") - if is_batched: - raise InvalidError( - f"Batch function {finalized_function.callable.__name__} cannot return generators." - ) - assert len(function_call_ids) == 1 # Send up to this many outputs at a time. generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024) generator_output_task: concurrent.futures.Future = container_io_manager.generator_output_task( # type: ignore function_call_ids[0], - finalized_function.data_format, + io_context.finalized_function.data_format, generator_queue, _future=True, # type: ignore # Synchronicity magic to return a future. ) @@ -492,36 +403,16 @@ def run_input_sync( container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL) generator_output_task.result() # Wait to finish sending generator outputs. message = api_pb2.GeneratorDone(items_total=item_count) - container_io_manager.push_output( - input_ids, - is_batched, - started_at, - finalized_function.callable.__name__, - message, - api_pb2.DATA_FORMAT_GENERATOR_DONE, - ) + container_io_manager.push_output(io_context, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE) else: if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res): raise InvalidError( f"Sync (non-generator) function return value of type {type(res)}." " You might need to use @app.function(..., is_generator=True)." ) - container_io_manager.push_output( - input_ids, - is_batched, - started_at, - finalized_function.callable.__name__, - res, - finalized_function.data_format, - ) + container_io_manager.push_output(io_context, started_at, res, io_context.finalized_function.data_format) reset_context() - def _get_finalized_functions(local_inputs: List[LocalInput]) -> FinalizedFunction: - assert len(local_inputs) > 0 - # all functions in a batch must have the same method name - assert len(set([local_input.method_name for local_input in local_inputs])) == 1 - return finalized_functions[local_inputs[0].method_name] - if input_concurrency > 1: with DaemonizedThreadPool(max_threads=input_concurrency) as thread_pool: @@ -530,28 +421,26 @@ async def run_concurrent_inputs(): # but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s # for them to resolve gracefully: async with TaskContext(0.01) as task_context: - async for local_inputs in container_io_manager.run_inputs_outputs.aio( - input_concurrency, batch_max_size, batch_linger_ms + async for io_context in container_io_manager.run_inputs_outputs.aio( + finalized_functions, input_concurrency, batch_max_size, batch_linger_ms ): - finalized_function = _get_finalized_functions(local_inputs) # Note that run_inputs_outputs will not return until the concurrency semaphore has # released all its slots so that they can be acquired by the run_inputs_outputs finalizer # This prevents leaving the task_context before outputs have been created # TODO: refactor to make this a bit more easy to follow? - if finalized_function.is_async: - task_context.create_task( - run_input_async(finalized_function, local_inputs, container_io_manager) - ) + if io_context.finalized_function.is_async: + task_context.create_task(run_input_async(io_context)) else: # run sync input in thread - thread_pool.submit(run_input_sync, finalized_function, local_inputs, container_io_manager) + thread_pool.submit(run_input_sync, io_context) user_code_event_loop.run(run_concurrent_inputs()) else: - for local_inputs in container_io_manager.run_inputs_outputs(input_concurrency, batch_max_size, batch_linger_ms): - finalized_function = _get_finalized_functions(local_inputs) - if finalized_function.is_async: - user_code_event_loop.run(run_input_async(finalized_function, local_inputs, container_io_manager)) + for io_context in container_io_manager.run_inputs_outputs( + finalized_functions, input_concurrency, batch_max_size, batch_linger_ms + ): + if io_context.finalized_function.is_async: + user_code_event_loop.run(run_input_async(io_context)) else: # Set up a custom signal handler for `SIGUSR1`, which gets translated to an InputCancellation # during function execution. This is sent to cancel inputs from the user @@ -562,7 +451,7 @@ def _cancel_input_signal_handler(signum, stackframe): # run this sync code in the main thread, blocking the "userland" event loop # this lets us cancel it using a signal handler that raises an exception try: - run_input_sync(finalized_function, local_inputs, container_io_manager) + run_input_sync(io_context) finally: signal.signal(signal.SIGUSR1, usr1_handler) # reset signal handler diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index d3c8e34f9..54a5ab37c 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -1,5 +1,6 @@ # Copyright Modal Labs 2024 import asyncio +import inspect import json import math import os @@ -43,12 +44,159 @@ class Sentinel: @dataclass -class LocalInput: - input_id: str - function_call_id: str - method_name: str - args: Any - kwargs: Any +class FinalizedFunction: + callable: Callable[..., Any] + is_async: bool + is_generator: bool + data_format: int # api_pb2.DataFormat + + +class IOContext: + input_ids: List[str] + function_call_ids: List[str] + finalized_function: FinalizedFunction + + async def create( + container_io_manager: "_ContainerIOManager", + finalized_functions: Dict[str, FinalizedFunction], + inputs: List[Tuple[str, str, api_pb2.FunctionInput]], + is_batched: bool, + ) -> None: + self = IOContext() + assert len(inputs) > 0 + self.input_ids, self.function_call_ids, self.inputs = zip(*inputs) + self.is_batched = is_batched + + self.inputs = await asyncio.gather(*[self.populate_input_blobs(input) for input in self.inputs]) + # check every input in batch executes the same function + method_name = self.inputs[0].method_name + assert all(method_name == input.method_name for input in self.inputs) + self.finalized_function = finalized_functions[method_name] + self.deserialized_args = [ + container_io_manager.deserialize(input.args) if input.args else ((), {}) for input in self.inputs + ] + return self + + async def populate_input_blobs(self, input: api_pb2.FunctionInput): + # If we got a pointer to a blob, download it from S3. + if input.WhichOneof("args_oneof") == "args_blob_id": + args = await blob_download(input.args_blob_id, self._client.stub) + + # Mutating + input.ClearField("args_blob_id") + input.args = args + return input + + return input + + def _args_and_kwargs(self): + if not self.is_batched: + assert len(self.inputs) == 1 + return self.deserialized_args[0] + + func_name = self.finalized_function.callable.__name__ + # batched function cannot be generator + if self.finalized_function.is_generator: + raise InvalidError(f"Modal batched function {func_name} cannot be a generator.") + + # batched function cannot have default arguments + param_names = [] + for param in inspect.signature(self.finalized_function.callable).parameters.values(): + param_names.append(param.name) + if param.default is not inspect.Parameter.empty: + raise InvalidError(f"Modal batched function {func_name} does not accept default arguments.") + + # aggregate args and kwargs of all inputs into a kwarg dict + kwargs_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(self.input_ids))] + for i, (args, kwargs) in enumerate(self.deserialized_args): + # check that all batched inputs should have the same number of args and kwargs + if (num_params := len(args) + len(kwargs)) != len(param_names): + raise InvalidError( + f"Modal batched function {func_name} takes {len(param_names)} positional arguments, but one call has {num_params}.." # noqa + ) + + for j, arg in enumerate(args): + kwargs_by_inputs[i][param_names[j]] = arg + for k, v in kwargs.items(): + if k not in param_names: + raise InvalidError( + f"Modal batched function {func_name} got unexpected keyword argument {k} in one call." + ) + if k in kwargs_by_inputs[i]: + raise InvalidError( + f"Modal batched function {func_name} got multiple values for argument {k} in one call." + ) + kwargs_by_inputs[i][k] = v + + formatted_kwargs = { + param_name: [kwargs_by_inputs[i][param_name] for i in range(len(kwargs_by_inputs))] + for param_name in param_names + } + return (), formatted_kwargs + + def call_finalized_function(self) -> Any: + args, kwargs = self._args_and_kwargs() + logger.debug(f"Starting input {self.input_ids} (async)") + res = self.finalized_function.callable(*args, **kwargs) + logger.debug(f"Finished input {self.input_ids} (async)") + return res + + def serialize_data_format(self, obj: Any, data_format: int) -> bytes: + return serialize_data_format(obj, data_format) + + def deserialize_data_format(self, data: bytes, data_format: int) -> Any: + return deserialize_data_format(data, data_format, self._client) + + async def _format_data(self, data: bytes, kwargs: Dict[str, Any], blob_func: Callable) -> Dict[str, Any]: + if len(data) > MAX_OBJECT_SIZE_BYTES: + kwargs["data_blob_id"] = await blob_func(data) + else: + kwargs["data"] = data + return kwargs + + @synchronizer.no_io_translation + async def format_output( + self, started_at: float, data_format: int, blob_func: Callable, **kwargs + ) -> List[api_pb2.FunctionPutOutputsItem]: + if "data" not in kwargs: + kwargs_list = [kwargs] * len(self.input_ids) + # data is not batched, return a single kwargs. + elif not self.is_batched and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS: + data = self.serialize_data_format(kwargs.pop("data"), data_format) + kwargs_list = [await self._format_data(data, kwargs, blob_func)] + elif not self.is_batched: # data is not batched and is an exception + kwargs_list = [await self._format_data(kwargs.pop("data"), kwargs, blob_func)] + + # data is batched, return a list of kwargs + # split the list of data in kwargs to respective input_ids and report error for every input_id in batch call. + elif "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: + error_data = kwargs.pop("data") + kwargs_list = await asyncio.gather( + *[self._format_data(error_data, kwargs, blob_func) for _ in self.input_ids] + ) + else: + function_name = self.finalized_function.callable.__name__ + data = kwargs.pop("data") + if not isinstance(data, list): + raise InvalidError(f"Output of batch function {function_name} must be a list.") + if len(data) != len(self.input_ids): + raise InvalidError( + f"Output of batch function {function_name} must be a list of the same length as its inputs." + ) + kwargs_list = await asyncio.gather( + *[self._format_data(self.serialize_data_format(d, data_format), kwargs.copy(), blob_func) for d in data] + ) + + return [ + api_pb2.FunctionPutOutputsItem( + input_id=input_id, + input_started_at=started_at, + output_created_at=time.time(), + result=api_pb2.GenericResult(**kwargs), + data_format=data_format, + ) + for input_id, kwargs in zip(self.input_ids, kwargs_list) + ] class _ContainerIOManager: @@ -99,8 +247,6 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client): self.current_input_started_at = None self._input_concurrency = None - self._batch_max_size = None - self._batch_linger_ms = None self._semaphore = None self._environment_name = container_args.environment_name @@ -124,9 +270,6 @@ def _reset_singleton(cls): """Only used for tests.""" cls._singleton = None - def is_batched(self) -> bool: - return self.function_def.batch_max_size > 0 - async def _run_heartbeat_loop(self): while 1: t0 = time.monotonic() @@ -261,12 +404,8 @@ def serialize(self, obj: Any) -> bytes: def deserialize(self, data: bytes) -> Any: return deserialize(data, self._client) - @synchronizer.no_io_translation - def serialize_data_format(self, obj: Any, data_format: int) -> bytes: - return serialize_data_format(obj, data_format) - - def deserialize_data_format(self, data: bytes, data_format: int) -> Any: - return deserialize_data_format(data, data_format, self._client) + async def blob_upload(self, data: bytes) -> str: + return await blob_upload(data, self._client.stub) async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]: """Read from the `data_in` stream of a function call.""" @@ -334,14 +473,6 @@ async def _queue_put(self, queue: asyncio.Queue, value: Any) -> None: """Put a value onto a queue, using the synchronicity event loop.""" await queue.put(value) - async def populate_input_blobs(self, item: api_pb2.FunctionInput): - args = await blob_download(item.args_blob_id, self._client.stub) - - # Mutating - item.ClearField("args_blob_id") - item.args = args - return item - def get_average_call_time(self) -> float: if self.calls_completed == 0: return 0 @@ -357,156 +488,85 @@ def get_max_inputs_to_fetch(self): @synchronizer.no_io_translation async def _generate_inputs( self, + batch_max_size: int, + batch_linger_ms: int, ) -> AsyncIterator[List[Tuple[str, str, api_pb2.FunctionInput]]]: request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id) - eof_received = False iteration = 0 - while not eof_received and self._fetching_inputs: + while self._fetching_inputs: request.average_call_time = self.get_average_call_time() request.max_values = self.get_max_inputs_to_fetch() # Deprecated; remove. request.input_concurrency = self._input_concurrency - request.batch_max_size = self._batch_max_size - request.batch_linger_ms = self._batch_linger_ms + request.batch_max_size, request.batch_linger_ms = batch_max_size, batch_linger_ms await self._semaphore.acquire() - yielded = False - try: - # If number of active inputs is at max queue size, this will block. - iteration += 1 - response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors( - self._client.stub.FunctionGetInputs, request - ) + # If number of active inputs is at max queue size, this will block. + iteration += 1 + response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors( + self._client.stub.FunctionGetInputs, request + ) - if response.rate_limit_sleep_duration: - logger.info( - "Task exceeded rate limit, sleeping for %.2fs before trying again." - % response.rate_limit_sleep_duration - ) - await asyncio.sleep(response.rate_limit_sleep_duration) - elif response.inputs: - # for input cancellations and concurrency logic we currently assume - # that there is no input buffering in the container - assert 0 < len(response.inputs) <= max(1, request.batch_max_size) - inputs_list = [] - for item in response.inputs: + if response.rate_limit_sleep_duration: + logger.info( + "Task exceeded rate limit, sleeping for %.2fs before trying again." + % response.rate_limit_sleep_duration + ) + await asyncio.sleep(response.rate_limit_sleep_duration) + self._semaphore.release() + elif response.inputs: + # for input cancellations and concurrency logic we currently assume + # that there is no input buffering in the container + assert 0 < len(response.inputs) <= max(1, request.batch_max_size) + inputs = [] + for item in response.inputs: + if item.kill_switch or item.input.final_input: if item.kill_switch: - assert len(response.inputs) == 1 logger.debug(f"Task {self.task_id} input kill signal input.") - eof_received = True - break - if item.input_id in self.cancelled_input_ids: - assert request.batch_max_size == 0 - continue - - # If we got a pointer to a blob, download it from S3. - if item.input.WhichOneof("args_oneof") == "args_blob_id": - input_pb = await self.populate_input_blobs(item.input) - else: - input_pb = item.input - - inputs_list.append((item.input_id, item.function_call_id, input_pb)) - - if item.input.final_input: - eof_received = True - if request.batch_max_size != 0: - logger.error("Final input not expected in batch input stream") - break - - if not eof_received: - yield inputs_list - yielded = True - # We only support max_inputs = 1 at the moment - if self.function_def.max_inputs == 1: - eof_received = True + if item.input.final_input and request.batch_max_size > 0: + logger.debug(f"Task {self.task_id} Final input not expected in batch input stream") + self._semaphore.release() + return + if item.input_id in self.cancelled_input_ids: + continue - finally: - if not yielded: - self._semaphore.release() + inputs.append((item.input_id, item.function_call_id, item.input)) + + yield inputs + # We only support max_inputs = 1 at the moment + if self.function_def.max_inputs == 1: + return @synchronizer.no_io_translation async def run_inputs_outputs( self, + finalized_functions: Dict[str, FinalizedFunction], input_concurrency: int = 1, batch_max_size: int = 0, batch_linger_ms: int = 0, - ) -> AsyncIterator[List[LocalInput]]: + ) -> AsyncIterator[IOContext]: # Ensure we do not fetch new inputs when container is too busy. # Before trying to fetch an input, acquire the semaphore: # - if no input is fetched, release the semaphore. # - or, when the output for the fetched input is sent, release the semaphore. self._input_concurrency = input_concurrency - self._batch_max_size = batch_max_size - self._batch_linger_ms = batch_linger_ms self._semaphore = asyncio.Semaphore(input_concurrency) - async for inputs_list in self._generate_inputs(): - local_inputs_list = [] - for input_id, function_call_id, input_pb in inputs_list: - args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) - self.current_input_id, self.current_input_started_at = (input_id, time.time()) - local_inputs_list.append(LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs)) - yield local_inputs_list + async for inputs in self._generate_inputs(batch_max_size, batch_linger_ms): + io_context = await IOContext.create(self, finalized_functions, inputs, batch_max_size > 0) + # TODO(Cathy) investigate this thing when current_input_id is list + self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time() + yield io_context self.current_input_id, self.current_input_started_at = (None, None) # collect all active input slots, meaning all inputs have wrapped up. for _ in range(input_concurrency): await self._semaphore.acquire() - async def _get_kwargs_process_blob_data(self, data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]: - if len(data) > MAX_OBJECT_SIZE_BYTES: - kwargs["data_blob_id"] = await blob_upload(data, self._client.stub) - else: - kwargs["data"] = data - return kwargs - - async def _get_kwargs( - self, kwargs: Dict[str, Any], input_ids: List[str], function_name: str, data_format: int, is_batched: bool - ) -> List[Dict[str, Any]]: - if "data" not in kwargs or not kwargs["data"]: - return [kwargs] * len(input_ids) - # data is not batched, return a single kwargs. - if not is_batched: - return [await self._get_kwargs_process_blob_data(kwargs.pop("data"), kwargs)] - # data is batched, return a list of kwargs - # split the list of data in kwargs to respective input_ids and report error for every input_id in batch call. - if "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: - error_data = kwargs.pop("data") - return [await self._get_kwargs_process_blob_data(error_data, kwargs) for _ in input_ids] - else: - data = self.deserialize_data_format(kwargs.pop("data"), data_format) - if not isinstance(data, list): - raise InvalidError(f"Output of batch function {function_name} must be a list.") - if len(data) != len(input_ids): - raise InvalidError( - f"Output of batch function {function_name} must be a list of the same length as its inputs." - ) - return await asyncio.gather( - *[ - self._get_kwargs_process_blob_data(self.serialize_data_format(d, data_format), kwargs.copy()) - for d in data - ] - ) - + @synchronizer.no_io_translation async def _push_output( self, - input_ids: List[str], - is_batched: bool, - started_at: float, - function_name: str, - data_format=api_pb2.DATA_FORMAT_UNSPECIFIED, - **kwargs, + outputs: List[api_pb2.FunctionPutOutputsItem], ): - kwargs = await self._get_kwargs(kwargs, input_ids, function_name, data_format, is_batched) - outputs = [ - api_pb2.FunctionPutOutputsItem( - input_id=input_ids, - input_started_at=started_at, - output_created_at=time.time(), - result=api_pb2.GenericResult(**kwargs), - data_format=data_format, - ) - for input_ids, kwargs in zip(input_ids, kwargs) - ] await retry_transient_errors( self._client.stub.FunctionPutOutputs, api_pb2.FunctionPutOutputsRequest(outputs=outputs), @@ -570,7 +630,9 @@ async def handle_user_exception(self) -> AsyncGenerator[None, None]: @asynccontextmanager async def handle_input_exception( - self, input_ids: List[str], started_at: float, function_name: str, is_batched: bool + self, + io_context: IOContext, + started_at: float, ) -> AsyncGenerator[None, None]: """Handle an exception while processing a function input.""" try: @@ -586,13 +648,9 @@ async def handle_input_exception( # just skip creating any output for this input and keep going with the next instead # it should have been marked as cancelled already in the backend at this point so it # won't be retried - logger.warning(f"The current input ({input_ids=}) was cancelled by a user request") + logger.warning(f"The current input ({io_context.input_ids=}) was cancelled by a user request") await self.complete_call(started_at) return - except InvalidError as exc: - # If there is an error in batch function output, we need to explicitly raise it - if "Output of batch function" in exc.args[0]: - raise except BaseException as exc: # print exception so it's logged traceback.print_exc() @@ -613,12 +671,10 @@ async def handle_input_exception( repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000] repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception" - await self._push_output( - input_ids, - is_batched=is_batched, + outputs = await io_context.format_output( started_at=started_at, - function_name=function_name, data_format=api_pb2.DATA_FORMAT_PICKLE, + blob_func=self.blob_upload, status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, data=self.serialize_exception(exc), exception=repr_exc, @@ -626,6 +682,7 @@ async def handle_input_exception( serialized_tb=serialized_tb, tb_line_cache=tb_line_cache, ) + await self._push_output(outputs) await self.complete_call(started_at) async def complete_call(self, started_at): @@ -634,18 +691,15 @@ async def complete_call(self, started_at): self._semaphore.release() @synchronizer.no_io_translation - async def push_output( - self, input_id, is_batched: bool, started_at: float, function_name: str, data: Any, data_format: int - ) -> None: - await self._push_output( - input_id, - is_batched=is_batched, + async def push_output(self, io_context: IOContext, started_at: float, data: Any, data_format: int) -> None: + outputs = await io_context.format_output( started_at=started_at, - function_name=function_name, data_format=data_format, + blob_func=self.blob_upload, + data=data, status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS, - data=self.serialize_data_format(data, data_format), ) + await self._push_output(outputs) await self.complete_call(started_at) async def memory_restore(self) -> None: diff --git a/modal/execution_context.py b/modal/execution_context.py index bf1bd399d..340313292 100644 --- a/modal/execution_context.py +++ b/modal/execution_context.py @@ -1,6 +1,6 @@ # Copyright Modal Labs 2024 from contextvars import ContextVar -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional from modal._container_io_manager import _ContainerIOManager from modal._utils.async_utils import synchronize_api @@ -70,16 +70,10 @@ def process_stuff(): return None -def _set_current_context_ids( - input_ids: Union[str, List[str]], function_call_ids: Union[str, List[str]] -) -> Callable[[], None]: - if isinstance(input_ids, list): - assert len(input_ids) == len(function_call_ids) and len(input_ids) > 0 - input_id = input_ids[0] - function_call_id = function_call_ids[0] - else: - input_id = input_ids - function_call_id = function_call_ids +def _set_current_context_ids(input_ids: List[str], function_call_ids: List[str]) -> Callable[[], None]: + assert len(input_ids) == len(function_call_ids) and len(input_ids) > 0 + input_id = input_ids[0] + function_call_id = function_call_ids[0] input_token = _current_input_id.set(input_id) function_call_token = _current_function_call_id.set(function_call_id) diff --git a/test/test_asgi_wrapper.py b/test/test_asgi_wrapper.py index 4c71e2574..8d8aa0317 100644 --- a/test/test_asgi_wrapper.py +++ b/test/test_asgi_wrapper.py @@ -70,7 +70,7 @@ async def aio(_function_call_id): @pytest.mark.timeout(1) async def test_success(): mock_manager = MockIOManager() - _set_current_context_ids("in-123", "fc-123") + _set_current_context_ids(["in-123"], ["fc-123"]) wrapped_app = asgi_app_wrapper(app, mock_manager) asgi_scope = _asgi_get_scope("/") outputs = [output async for output in wrapped_app(asgi_scope)] @@ -88,7 +88,7 @@ async def test_success(): @pytest.mark.timeout(1) async def test_endpoint_exception(endpoint_url): mock_manager = MockIOManager() - _set_current_context_ids("in-123", "fc-123") + _set_current_context_ids(["in-123"], ["fc-123"]) wrapped_app = asgi_app_wrapper(app, mock_manager) asgi_scope = _asgi_get_scope(endpoint_url) outputs = [] @@ -121,7 +121,7 @@ async def test_broken_io_unused(caplog): # any of the body data, it should be allowed to output its data # and not raise an exception - but print a warning since it's unexpected mock_manager = BrokenIOManager() - _set_current_context_ids("in-123", "fc-123") + _set_current_context_ids(["in-123"], ["fc-123"]) wrapped_app = asgi_app_wrapper(app, mock_manager) asgi_scope = _asgi_get_scope("/") outputs = [] @@ -140,7 +140,7 @@ async def test_broken_io_unused(caplog): @pytest.mark.timeout(10) async def test_broken_io_used(): mock_manager = BrokenIOManager() - _set_current_context_ids("in-123", "fc-123") + _set_current_context_ids(["in-123"], ["fc-123"]) wrapped_app = asgi_app_wrapper(app, mock_manager) asgi_scope = _asgi_get_scope("/async_reading_body", "POST") outputs = [] @@ -164,7 +164,7 @@ async def aio(_function_call_id): @pytest.mark.timeout(2) async def test_first_message_timeout(monkeypatch): monkeypatch.setattr("modal._asgi.FIRST_MESSAGE_TIMEOUT_SECONDS", 0.1) # simulate timeout - _set_current_context_ids("in-123", "fc-123") + _set_current_context_ids(["in-123"], ["fc-123"]) wrapped_app = asgi_app_wrapper(app, SlowIOManager()) asgi_scope = _asgi_get_scope("/async_reading_body", "POST") outputs = [] @@ -180,7 +180,7 @@ async def test_first_message_timeout(monkeypatch): async def test_cancellation_cleanup(caplog): # this test mostly exists to get some coverage on the cancellation/error paths and # ensure nothing unexpected happens there - _set_current_context_ids("in-123", "fc-123") + _set_current_context_ids(["in-123"], ["fc-123"]) wrapped_app = asgi_app_wrapper(app, SlowIOManager()) asgi_scope = _asgi_get_scope("/async_reading_body", "POST") outputs = [] @@ -199,7 +199,7 @@ async def app_runner(): @pytest.mark.asyncio async def test_streaming_response(): - _set_current_context_ids("in-123", "fc-123") + _set_current_context_ids(["in-123"], ["fc-123"]) wrapped_app = asgi_app_wrapper(app, SlowIOManager()) asgi_scope = _asgi_get_scope("/streaming_response", "GET") outputs = [] @@ -225,7 +225,7 @@ async def aio(_function_call_id): @pytest.mark.asyncio async def test_streaming_body(): - _set_current_context_ids("in-123", "fc-123") + _set_current_context_ids(["in-123"], ["fc-123"]) wrapped_app = asgi_app_wrapper(app, StreamingIOManager()) asgi_scope = _asgi_get_scope("/async_reading_body", "POST") From 69d701f810181112645e70e83756b29645843acb Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Mon, 5 Aug 2024 14:12:05 +0000 Subject: [PATCH 09/23] fix type checck --- modal/_container_entrypoint.py | 3 +- modal/_container_io_manager.py | 82 +++++++++++++++++----------------- 2 files changed, 43 insertions(+), 42 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index c390418cc..84473f79d 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -18,7 +18,6 @@ from google.protobuf.message import Message from synchronicity import Interface -from modal._container_io_manager import FinalizedFunction, IOContext from modal_proto import api_pb2 from ._asgi import ( @@ -29,7 +28,7 @@ webhook_asgi_app, wsgi_app_wrapper, ) -from ._container_io_manager import ContainerIOManager, UserException, _ContainerIOManager +from ._container_io_manager import ContainerIOManager, FinalizedFunction, IOContext, UserException, _ContainerIOManager from ._proxy_tunnel import proxy_tunnel from ._serialization import deserialize, deserialize_proto_params from ._utils.async_utils import TaskContext, synchronizer diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 54a5ab37c..051157df1 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -19,7 +19,7 @@ from modal_proto import api_pb2 -from ._serialization import deserialize, deserialize_data_format, serialize, serialize_data_format +from ._serialization import deserialize, serialize, serialize_data_format from ._traceback import extract_traceback from ._utils.async_utils import TaskContext, asyncify, synchronize_api, synchronizer from ._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload @@ -56,42 +56,42 @@ class IOContext: function_call_ids: List[str] finalized_function: FinalizedFunction + @classmethod async def create( + cls, container_io_manager: "_ContainerIOManager", finalized_functions: Dict[str, FinalizedFunction], inputs: List[Tuple[str, str, api_pb2.FunctionInput]], is_batched: bool, - ) -> None: - self = IOContext() - assert len(inputs) > 0 - self.input_ids, self.function_call_ids, self.inputs = zip(*inputs) - self.is_batched = is_batched - - self.inputs = await asyncio.gather(*[self.populate_input_blobs(input) for input in self.inputs]) + ) -> "IOContext": + self = cls.__new__(cls) + assert len(inputs) >= 1 if is_batched else len(inputs) == 1 + self.input_ids, self.function_call_ids, inputs = zip(*inputs) + self.inputs = await asyncio.gather( + *[self._populate_input_blobs(container_io_manager, input) for input in inputs] + ) # check every input in batch executes the same function method_name = self.inputs[0].method_name assert all(method_name == input.method_name for input in self.inputs) self.finalized_function = finalized_functions[method_name] self.deserialized_args = [ - container_io_manager.deserialize(input.args) if input.args else ((), {}) for input in self.inputs + container_io_manager.deserialize(input.args) if input.args else ((), {}) for input in inputs ] + self.is_batched = is_batched return self - async def populate_input_blobs(self, input: api_pb2.FunctionInput): + async def _populate_input_blobs(self, container_io_manager: "_ContainerIOManager", input: api_pb2.FunctionInput): # If we got a pointer to a blob, download it from S3. if input.WhichOneof("args_oneof") == "args_blob_id": - args = await blob_download(input.args_blob_id, self._client.stub) - + args = await container_io_manager.blob_download(input.args_blob_id) # Mutating input.ClearField("args_blob_id") input.args = args - return input return input def _args_and_kwargs(self): if not self.is_batched: - assert len(self.inputs) == 1 return self.deserialized_args[0] func_name = self.finalized_function.callable.__name__ @@ -141,38 +141,27 @@ def call_finalized_function(self) -> Any: logger.debug(f"Finished input {self.input_ids} (async)") return res - def serialize_data_format(self, obj: Any, data_format: int) -> bytes: - return serialize_data_format(obj, data_format) - - def deserialize_data_format(self, data: bytes, data_format: int) -> Any: - return deserialize_data_format(data, data_format, self._client) - - async def _format_data(self, data: bytes, kwargs: Dict[str, Any], blob_func: Callable) -> Dict[str, Any]: - if len(data) > MAX_OBJECT_SIZE_BYTES: - kwargs["data_blob_id"] = await blob_func(data) - else: - kwargs["data"] = data - return kwargs - @synchronizer.no_io_translation - async def format_output( - self, started_at: float, data_format: int, blob_func: Callable, **kwargs + async def format_outputs( + self, container_io_manager: "_ContainerIOManager", started_at: float, data_format: int, **kwargs ) -> List[api_pb2.FunctionPutOutputsItem]: if "data" not in kwargs: kwargs_list = [kwargs] * len(self.input_ids) # data is not batched, return a single kwargs. - elif not self.is_batched and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS: - data = self.serialize_data_format(kwargs.pop("data"), data_format) - kwargs_list = [await self._format_data(data, kwargs, blob_func)] - elif not self.is_batched: # data is not batched and is an exception - kwargs_list = [await self._format_data(kwargs.pop("data"), kwargs, blob_func)] + elif not self.is_batched: + data = ( + serialize_data_format(kwargs.pop("data"), data_format) + if kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS + else kwargs.pop("data") + ) + kwargs_list = [await container_io_manager.format_blob_data(data, kwargs)] # data is batched, return a list of kwargs # split the list of data in kwargs to respective input_ids and report error for every input_id in batch call. elif "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: error_data = kwargs.pop("data") kwargs_list = await asyncio.gather( - *[self._format_data(error_data, kwargs, blob_func) for _ in self.input_ids] + *[container_io_manager.format_blob_data(error_data, kwargs) for _ in self.input_ids] ) else: function_name = self.finalized_function.callable.__name__ @@ -184,7 +173,10 @@ async def format_output( f"Output of batch function {function_name} must be a list of the same length as its inputs." ) kwargs_list = await asyncio.gather( - *[self._format_data(self.serialize_data_format(d, data_format), kwargs.copy(), blob_func) for d in data] + *[ + container_io_manager.format_blob_data(serialize_data_format(d, data_format), kwargs.copy()) + for d in data + ] ) return [ @@ -407,6 +399,16 @@ def deserialize(self, data: bytes) -> Any: async def blob_upload(self, data: bytes) -> str: return await blob_upload(data, self._client.stub) + async def blob_download(self, blob_id: str) -> bytes: + return await blob_download(blob_id, self._client.stub) + + async def format_blob_data(self, data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]: + if len(data) > MAX_OBJECT_SIZE_BYTES: + kwargs["data_blob_id"] = await self.blob_upload(data) + else: + kwargs["data"] = data + return kwargs + async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]: """Read from the `data_in` stream of a function call.""" async for data in _stream_function_call_data(self._client, function_call_id, "data_in"): @@ -671,10 +673,10 @@ async def handle_input_exception( repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000] repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception" - outputs = await io_context.format_output( + outputs = await io_context.format_outputs( + container_io_manager=self, started_at=started_at, data_format=api_pb2.DATA_FORMAT_PICKLE, - blob_func=self.blob_upload, status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, data=self.serialize_exception(exc), exception=repr_exc, @@ -692,10 +694,10 @@ async def complete_call(self, started_at): @synchronizer.no_io_translation async def push_output(self, io_context: IOContext, started_at: float, data: Any, data_format: int) -> None: - outputs = await io_context.format_output( + outputs = await io_context.format_outputs( + container_io_manager=self, started_at=started_at, data_format=data_format, - blob_func=self.blob_upload, data=data, status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS, ) From cdb4ef99286f5c2eba951f21fed6ccfab7b38cda Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Mon, 5 Aug 2024 14:20:33 +0000 Subject: [PATCH 10/23] cleanup IOContest init --- modal/_container_io_manager.py | 53 ++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 051157df1..4007b8cde 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -56,6 +56,20 @@ class IOContext: function_call_ids: List[str] finalized_function: FinalizedFunction + def __init__( + self, + input_ids: List[str], + function_call_ids: List[str], + finalized_function: FinalizedFunction, + deserialized_args: List, + is_batched: bool, + ): + self.input_ids = input_ids + self.function_call_ids = function_call_ids + self.finalized_function = finalized_function + self.deserialized_args = deserialized_args + self.is_batched = is_batched + @classmethod async def create( cls, @@ -64,31 +78,28 @@ async def create( inputs: List[Tuple[str, str, api_pb2.FunctionInput]], is_batched: bool, ) -> "IOContext": - self = cls.__new__(cls) assert len(inputs) >= 1 if is_batched else len(inputs) == 1 - self.input_ids, self.function_call_ids, inputs = zip(*inputs) - self.inputs = await asyncio.gather( - *[self._populate_input_blobs(container_io_manager, input) for input in inputs] - ) + input_ids, function_call_ids, inputs = zip(*inputs) + + async def _populate_input_blobs(container_io_manager: "_ContainerIOManager", input: api_pb2.FunctionInput): + # If we got a pointer to a blob, download it from S3. + if input.WhichOneof("args_oneof") == "args_blob_id": + args = await container_io_manager.blob_download(input.args_blob_id) + # Mutating + input.ClearField("args_blob_id") + input.args = args + + return input + + inputs = await asyncio.gather(*[_populate_input_blobs(container_io_manager, input) for input in inputs]) # check every input in batch executes the same function - method_name = self.inputs[0].method_name - assert all(method_name == input.method_name for input in self.inputs) - self.finalized_function = finalized_functions[method_name] - self.deserialized_args = [ + method_name = inputs[0].method_name + assert all(method_name == input.method_name for input in inputs) + finalized_function = finalized_functions[method_name] + deserialized_args = [ container_io_manager.deserialize(input.args) if input.args else ((), {}) for input in inputs ] - self.is_batched = is_batched - return self - - async def _populate_input_blobs(self, container_io_manager: "_ContainerIOManager", input: api_pb2.FunctionInput): - # If we got a pointer to a blob, download it from S3. - if input.WhichOneof("args_oneof") == "args_blob_id": - args = await container_io_manager.blob_download(input.args_blob_id) - # Mutating - input.ClearField("args_blob_id") - input.args = args - - return input + return cls(input_ids, function_call_ids, finalized_function, deserialized_args, is_batched) def _args_and_kwargs(self): if not self.is_batched: From 0aeaad322052f940669c9e11b172bf05a58a2196 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Mon, 5 Aug 2024 14:32:42 +0000 Subject: [PATCH 11/23] fix synchronize --- modal/_container_io_manager.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 4007b8cde..106c8baa5 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -152,7 +152,6 @@ def call_finalized_function(self) -> Any: logger.debug(f"Finished input {self.input_ids} (async)") return res - @synchronizer.no_io_translation async def format_outputs( self, container_io_manager: "_ContainerIOManager", started_at: float, data_format: int, **kwargs ) -> List[api_pb2.FunctionPutOutputsItem]: @@ -575,7 +574,6 @@ async def run_inputs_outputs( for _ in range(input_concurrency): await self._semaphore.acquire() - @synchronizer.no_io_translation async def _push_output( self, outputs: List[api_pb2.FunctionPutOutputsItem], From 24b9689a7e5592562dcf6e1995adaa558f604830 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Mon, 5 Aug 2024 16:58:34 +0000 Subject: [PATCH 12/23] cleanup --- modal/_container_io_manager.py | 182 ++++++++++++++++----------------- 1 file changed, 90 insertions(+), 92 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 106c8baa5..534ff0e25 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -67,8 +67,8 @@ def __init__( self.input_ids = input_ids self.function_call_ids = function_call_ids self.finalized_function = finalized_function - self.deserialized_args = deserialized_args - self.is_batched = is_batched + self._deserialized_args = deserialized_args + self._is_batched = is_batched @classmethod async def create( @@ -96,14 +96,15 @@ async def _populate_input_blobs(container_io_manager: "_ContainerIOManager", inp method_name = inputs[0].method_name assert all(method_name == input.method_name for input in inputs) finalized_function = finalized_functions[method_name] + # TODO(cathy) Performance decrease if we deserialize inputs later deserialized_args = [ container_io_manager.deserialize(input.args) if input.args else ((), {}) for input in inputs ] return cls(input_ids, function_call_ids, finalized_function, deserialized_args, is_batched) def _args_and_kwargs(self): - if not self.is_batched: - return self.deserialized_args[0] + if not self._is_batched: + return self._deserialized_args[0] func_name = self.finalized_function.callable.__name__ # batched function cannot be generator @@ -119,7 +120,7 @@ def _args_and_kwargs(self): # aggregate args and kwargs of all inputs into a kwarg dict kwargs_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(self.input_ids))] - for i, (args, kwargs) in enumerate(self.deserialized_args): + for i, (args, kwargs) in enumerate(self._deserialized_args): # check that all batched inputs should have the same number of args and kwargs if (num_params := len(args) + len(kwargs)) != len(param_names): raise InvalidError( @@ -152,53 +153,18 @@ def call_finalized_function(self) -> Any: logger.debug(f"Finished input {self.input_ids} (async)") return res - async def format_outputs( - self, container_io_manager: "_ContainerIOManager", started_at: float, data_format: int, **kwargs - ) -> List[api_pb2.FunctionPutOutputsItem]: - if "data" not in kwargs: - kwargs_list = [kwargs] * len(self.input_ids) - # data is not batched, return a single kwargs. - elif not self.is_batched: - data = ( - serialize_data_format(kwargs.pop("data"), data_format) - if kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS - else kwargs.pop("data") - ) - kwargs_list = [await container_io_manager.format_blob_data(data, kwargs)] - - # data is batched, return a list of kwargs - # split the list of data in kwargs to respective input_ids and report error for every input_id in batch call. - elif "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: - error_data = kwargs.pop("data") - kwargs_list = await asyncio.gather( - *[container_io_manager.format_blob_data(error_data, kwargs) for _ in self.input_ids] - ) - else: + def validate_output_data(self, data: Any) -> None: + if self._is_batched: function_name = self.finalized_function.callable.__name__ - data = kwargs.pop("data") if not isinstance(data, list): raise InvalidError(f"Output of batch function {function_name} must be a list.") if len(data) != len(self.input_ids): raise InvalidError( f"Output of batch function {function_name} must be a list of the same length as its inputs." ) - kwargs_list = await asyncio.gather( - *[ - container_io_manager.format_blob_data(serialize_data_format(d, data_format), kwargs.copy()) - for d in data - ] - ) - - return [ - api_pb2.FunctionPutOutputsItem( - input_id=input_id, - input_started_at=started_at, - output_created_at=time.time(), - result=api_pb2.GenericResult(**kwargs), - data_format=data_format, - ) - for input_id, kwargs in zip(self.input_ids, kwargs_list) - ] + else: + data = [data] + return data class _ContainerIOManager: @@ -406,6 +372,10 @@ def serialize(self, obj: Any) -> bytes: def deserialize(self, data: bytes) -> Any: return deserialize(data, self._client) + @synchronizer.no_io_translation + def serialize_data_format(self, obj: Any, data_format: int) -> bytes: + return serialize_data_format(obj, data_format) + async def blob_upload(self, data: bytes) -> str: return await blob_upload(data, self._client.stub) @@ -512,41 +482,49 @@ async def _generate_inputs( request.batch_max_size, request.batch_linger_ms = batch_max_size, batch_linger_ms await self._semaphore.acquire() - # If number of active inputs is at max queue size, this will block. - iteration += 1 - response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors( - self._client.stub.FunctionGetInputs, request - ) - - if response.rate_limit_sleep_duration: - logger.info( - "Task exceeded rate limit, sleeping for %.2fs before trying again." - % response.rate_limit_sleep_duration + yielded = False + try: + # If number of active inputs is at max queue size, this will block. + iteration += 1 + response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors( + self._client.stub.FunctionGetInputs, request ) - await asyncio.sleep(response.rate_limit_sleep_duration) - self._semaphore.release() - elif response.inputs: - # for input cancellations and concurrency logic we currently assume - # that there is no input buffering in the container - assert 0 < len(response.inputs) <= max(1, request.batch_max_size) - inputs = [] - for item in response.inputs: - if item.kill_switch or item.input.final_input: + + if response.rate_limit_sleep_duration: + logger.info( + "Task exceeded rate limit, sleeping for %.2fs before trying again." + % response.rate_limit_sleep_duration + ) + await asyncio.sleep(response.rate_limit_sleep_duration) + elif response.inputs: + # for input cancellations and concurrency logic we currently assume + # that there is no input buffering in the container + assert 0 < len(response.inputs) <= max(1, request.batch_max_size) + inputs = [] + final_input_received = False + for item in response.inputs: if item.kill_switch: logger.debug(f"Task {self.task_id} input kill signal input.") - if item.input.final_input and request.batch_max_size > 0: - logger.debug(f"Task {self.task_id} Final input not expected in batch input stream") - self._semaphore.release() + return + if item.input_id in self.cancelled_input_ids: + continue + + inputs.append((item.input_id, item.function_call_id, item.input)) + if item.input.final_input: + if request.batch_max_size > 0: + logger.debug(f"Task {self.task_id} Final input not expected in batch input stream") + final_input_received = True + break + + yield inputs + yielded = True + + # We only support max_inputs = 1 at the moment + if final_input_received or self.function_def.max_inputs == 1: return - if item.input_id in self.cancelled_input_ids: - continue - - inputs.append((item.input_id, item.function_call_id, item.input)) - - yield inputs - # We only support max_inputs = 1 at the moment - if self.function_def.max_inputs == 1: - return + finally: + if not yielded: + self._semaphore.release() @synchronizer.no_io_translation async def run_inputs_outputs( @@ -565,7 +543,6 @@ async def run_inputs_outputs( async for inputs in self._generate_inputs(batch_max_size, batch_linger_ms): io_context = await IOContext.create(self, finalized_functions, inputs, batch_max_size > 0) - # TODO(Cathy) investigate this thing when current_input_id is list self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time() yield io_context self.current_input_id, self.current_input_started_at = (None, None) @@ -574,10 +551,33 @@ async def run_inputs_outputs( for _ in range(input_concurrency): await self._semaphore.acquire() - async def _push_output( - self, - outputs: List[api_pb2.FunctionPutOutputsItem], - ): + @synchronizer.no_io_translation + async def format_and_push_outputs( + self, io_context: IOContext, started_at: float, data_format: int, **kwargs + ) -> None: + if "data" not in kwargs: + kwargs_list = [kwargs] * len(io_context.input_ids) + elif "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: + exc_data = kwargs.pop("data") + # if batched, duplicate exception to all inputs + kwargs_list = await asyncio.gather(*[self.format_blob_data(exc_data, kwargs) for _ in io_context.input_ids]) + else: + data = io_context.validate_output_data(kwargs.pop("data")) + # if batched, split the list of data to all inputs + kwargs_list = await asyncio.gather( + *[self.format_blob_data(self.serialize_data_format(d, data_format), kwargs.copy()) for d in data] + ) + + outputs = [ + api_pb2.FunctionPutOutputsItem( + input_id=input_id, + input_started_at=started_at, + output_created_at=time.time(), + result=api_pb2.GenericResult(**kwargs), + data_format=data_format, + ) + for input_id, kwargs in zip(io_context.input_ids, kwargs_list) + ] await retry_transient_errors( self._client.stub.FunctionPutOutputs, api_pb2.FunctionPutOutputsRequest(outputs=outputs), @@ -659,8 +659,8 @@ async def handle_input_exception( # just skip creating any output for this input and keep going with the next instead # it should have been marked as cancelled already in the backend at this point so it # won't be retried - logger.warning(f"The current input ({io_context.input_ids=}) was cancelled by a user request") - await self.complete_call(started_at) + logger.warning(f"Received a cancellation signal while processing input {io_context.input_ids}") + await self.exit_context(started_at) return except BaseException as exc: # print exception so it's logged @@ -682,8 +682,8 @@ async def handle_input_exception( repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000] repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception" - outputs = await io_context.format_outputs( - container_io_manager=self, + await self.format_and_push_outputs( + io_context=io_context, started_at=started_at, data_format=api_pb2.DATA_FORMAT_PICKLE, status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, @@ -693,25 +693,23 @@ async def handle_input_exception( serialized_tb=serialized_tb, tb_line_cache=tb_line_cache, ) - await self._push_output(outputs) - await self.complete_call(started_at) + await self.exit_context(started_at) - async def complete_call(self, started_at): + async def exit_context(self, started_at): self.total_user_time += time.time() - started_at self.calls_completed += 1 self._semaphore.release() @synchronizer.no_io_translation async def push_output(self, io_context: IOContext, started_at: float, data: Any, data_format: int) -> None: - outputs = await io_context.format_outputs( - container_io_manager=self, + await self.format_and_push_outputs( + io_context=io_context, started_at=started_at, data_format=data_format, data=data, status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS, ) - await self._push_output(outputs) - await self.complete_call(started_at) + await self.exit_context(started_at) async def memory_restore(self) -> None: # Busy-wait for restore. `/__modal/restore-state.json` is created From 4d34354e2937fdef4d7821ce32928eead2604b19 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Mon, 5 Aug 2024 18:18:49 +0000 Subject: [PATCH 13/23] semantics change --- modal/_container_io_manager.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 534ff0e25..3f3727982 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -102,7 +102,7 @@ async def _populate_input_blobs(container_io_manager: "_ContainerIOManager", inp ] return cls(input_ids, function_call_ids, finalized_function, deserialized_args, is_batched) - def _args_and_kwargs(self): + def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List]]: if not self._is_batched: return self._deserialized_args[0] @@ -153,7 +153,7 @@ def call_finalized_function(self) -> Any: logger.debug(f"Finished input {self.input_ids} (async)") return res - def validate_output_data(self, data: Any) -> None: + def validate_output_data(self, data: Any) -> List[Any]: if self._is_batched: function_name = self.finalized_function.callable.__name__ if not isinstance(data, list): @@ -376,15 +376,12 @@ def deserialize(self, data: bytes) -> Any: def serialize_data_format(self, obj: Any, data_format: int) -> bytes: return serialize_data_format(obj, data_format) - async def blob_upload(self, data: bytes) -> str: - return await blob_upload(data, self._client.stub) - async def blob_download(self, blob_id: str) -> bytes: return await blob_download(blob_id, self._client.stub) async def format_blob_data(self, data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]: if len(data) > MAX_OBJECT_SIZE_BYTES: - kwargs["data_blob_id"] = await self.blob_upload(data) + kwargs["data_blob_id"] = await blob_upload(data, self._client.stub) else: kwargs["data"] = data return kwargs @@ -516,6 +513,7 @@ async def _generate_inputs( final_input_received = True break + # If yielded, allow semaphore to be released via exit_context yield inputs yielded = True @@ -558,7 +556,7 @@ async def format_and_push_outputs( if "data" not in kwargs: kwargs_list = [kwargs] * len(io_context.input_ids) elif "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: - exc_data = kwargs.pop("data") + exc_data = self.serialize_exception(kwargs.pop("data")) # if batched, duplicate exception to all inputs kwargs_list = await asyncio.gather(*[self.format_blob_data(exc_data, kwargs) for _ in io_context.input_ids]) else: @@ -687,7 +685,7 @@ async def handle_input_exception( started_at=started_at, data_format=api_pb2.DATA_FORMAT_PICKLE, status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, - data=self.serialize_exception(exc), + data=exc, exception=repr_exc, traceback=traceback.format_exc(), serialized_tb=serialized_tb, From 725b126e7f227644205072e06ffd55fb92485a66 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Mon, 5 Aug 2024 19:41:02 +0000 Subject: [PATCH 14/23] function name change --- modal/_container_entrypoint.py | 10 ++++++---- modal/_container_io_manager.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 84473f79d..5dd530b61 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -352,7 +352,7 @@ async def run_input_async(io_context: IOContext) -> None: await container_io_manager._queue_put.aio(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL) await generator_output_task # Wait to finish sending generator outputs. message = api_pb2.GeneratorDone(items_total=item_count) - await container_io_manager.push_output.aio( + await container_io_manager.push_outputs.aio( io_context, started_at, message, @@ -365,7 +365,7 @@ async def run_input_async(io_context: IOContext) -> None: " You might need to use @app.function(..., is_generator=True)." ) value = await res - await container_io_manager.push_output.aio( + await container_io_manager.format_and_push_outputs.aio( io_context, started_at, value, @@ -402,14 +402,16 @@ def run_input_sync(io_context: IOContext) -> None: container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL) generator_output_task.result() # Wait to finish sending generator outputs. message = api_pb2.GeneratorDone(items_total=item_count) - container_io_manager.push_output(io_context, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE) + container_io_manager.push_outputs(io_context, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE) else: if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res): raise InvalidError( f"Sync (non-generator) function return value of type {type(res)}." " You might need to use @app.function(..., is_generator=True)." ) - container_io_manager.push_output(io_context, started_at, res, io_context.finalized_function.data_format) + container_io_manager.push_outputs( + io_context, started_at, res, io_context.finalized_function.data_format + ) reset_context() if input_concurrency > 1: diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 3f3727982..4852e12be 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -699,7 +699,7 @@ async def exit_context(self, started_at): self._semaphore.release() @synchronizer.no_io_translation - async def push_output(self, io_context: IOContext, started_at: float, data: Any, data_format: int) -> None: + async def push_outputs(self, io_context: IOContext, started_at: float, data: Any, data_format: int) -> None: await self.format_and_push_outputs( io_context=io_context, started_at=started_at, From 77a220c8b84729fce707575fa2a978590f36b8ff Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Mon, 5 Aug 2024 19:54:36 +0000 Subject: [PATCH 15/23] fix --- modal/_container_entrypoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 5dd530b61..b3fb65848 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -365,7 +365,7 @@ async def run_input_async(io_context: IOContext) -> None: " You might need to use @app.function(..., is_generator=True)." ) value = await res - await container_io_manager.format_and_push_outputs.aio( + await container_io_manager.push_outputs.aio( io_context, started_at, value, From 78f7fae9f94426e6093e69853defb78b32ed427a Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Wed, 7 Aug 2024 13:23:56 +0000 Subject: [PATCH 16/23] cleanup push outputs --- modal/_container_io_manager.py | 111 +++++++++++++++++---------------- 1 file changed, 58 insertions(+), 53 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 4852e12be..2c16150c0 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -73,7 +73,7 @@ def __init__( @classmethod async def create( cls, - container_io_manager: "_ContainerIOManager", + client: _Client, finalized_functions: Dict[str, FinalizedFunction], inputs: List[Tuple[str, str, api_pb2.FunctionInput]], is_batched: bool, @@ -81,25 +81,23 @@ async def create( assert len(inputs) >= 1 if is_batched else len(inputs) == 1 input_ids, function_call_ids, inputs = zip(*inputs) - async def _populate_input_blobs(container_io_manager: "_ContainerIOManager", input: api_pb2.FunctionInput): + async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput): # If we got a pointer to a blob, download it from S3. if input.WhichOneof("args_oneof") == "args_blob_id": - args = await container_io_manager.blob_download(input.args_blob_id) + args = await blob_download(input.args_blob_id, client) # Mutating input.ClearField("args_blob_id") input.args = args return input - inputs = await asyncio.gather(*[_populate_input_blobs(container_io_manager, input) for input in inputs]) + inputs = await asyncio.gather(*[_populate_input_blobs(client, input) for input in inputs]) # check every input in batch executes the same function method_name = inputs[0].method_name assert all(method_name == input.method_name for input in inputs) finalized_function = finalized_functions[method_name] # TODO(cathy) Performance decrease if we deserialize inputs later - deserialized_args = [ - container_io_manager.deserialize(input.args) if input.args else ((), {}) for input in inputs - ] + deserialized_args = [deserialize(input.args, client) if input.args else ((), {}) for input in inputs] return cls(input_ids, function_call_ids, finalized_function, deserialized_args, is_batched) def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List]]: @@ -186,6 +184,8 @@ class _ContainerIOManager: current_input_id: Optional[str] current_input_started_at: Optional[float] + client: _Client + _input_concurrency: Optional[int] _semaphore: Optional[asyncio.Semaphore] _environment_name: str @@ -196,8 +196,6 @@ class _ContainerIOManager: _is_interactivity_enabled: bool _fetching_inputs: bool - _client: _Client - _GENERATOR_STOP_SENTINEL: ClassVar[Sentinel] = Sentinel() _singleton: ClassVar[Optional["_ContainerIOManager"]] = None @@ -225,8 +223,8 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client): self._is_interactivity_enabled = False self._fetching_inputs = True - self._client = client - assert isinstance(self._client, _Client) + self.client = client + assert isinstance(self.client, _Client) def __new__(cls, container_args: api_pb2.ContainerArguments, client: _Client) -> "_ContainerIOManager": cls._singleton = super().__new__(cls) @@ -280,7 +278,7 @@ async def _heartbeat_handle_cancellations(self) -> bool: # TODO(erikbern): capture exceptions? response = await retry_transient_errors( - self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT + self.client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT ) if response.HasField("cancel_input_event"): @@ -332,7 +330,7 @@ def stop_heartbeat(self): async def get_app_objects(self) -> RunningApp: req = api_pb2.AppGetObjectsRequest(app_id=self.app_id, include_unindexed=True) - resp = await retry_transient_errors(self._client.stub.AppGetObjects, req) + resp = await retry_transient_errors(self.client.stub.AppGetObjects, req) logger.debug(f"AppGetObjects received {len(resp.items)} objects for app {self.app_id}") tag_to_object_id = {} @@ -353,7 +351,7 @@ async def get_app_objects(self) -> RunningApp: async def get_serialized_function(self) -> Tuple[Optional[Any], Callable]: # Fetch the serialized function definition request = api_pb2.FunctionGetSerializedRequest(function_id=self.function_id) - response = await self._client.stub.FunctionGetSerialized(request) + response = await self.client.stub.FunctionGetSerialized(request) if response.function_serialized: fun = self.deserialize(response.function_serialized) else: @@ -370,25 +368,25 @@ def serialize(self, obj: Any) -> bytes: return serialize(obj) def deserialize(self, data: bytes) -> Any: - return deserialize(data, self._client) + return deserialize(data, self.client) @synchronizer.no_io_translation def serialize_data_format(self, obj: Any, data_format: int) -> bytes: return serialize_data_format(obj, data_format) async def blob_download(self, blob_id: str) -> bytes: - return await blob_download(blob_id, self._client.stub) + return await blob_download(blob_id, self.client.stub) async def format_blob_data(self, data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]: if len(data) > MAX_OBJECT_SIZE_BYTES: - kwargs["data_blob_id"] = await blob_upload(data, self._client.stub) + kwargs["data_blob_id"] = await blob_upload(data, self.client.stub) else: kwargs["data"] = data return kwargs async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]: """Read from the `data_in` stream of a function call.""" - async for data in _stream_function_call_data(self._client, function_call_id, "data_in"): + async for data in _stream_function_call_data(self.client, function_call_id, "data_in"): yield data async def put_data_out( @@ -408,13 +406,13 @@ async def put_data_out( for i, message_bytes in enumerate(messages_bytes): chunk = api_pb2.DataChunk(data_format=data_format, index=start_index + i) # type: ignore if len(message_bytes) > MAX_OBJECT_SIZE_BYTES: - chunk.data_blob_id = await blob_upload(message_bytes, self._client.stub) + chunk.data_blob_id = await blob_upload(message_bytes, self.client.stub) else: chunk.data = message_bytes data_chunks.append(chunk) req = api_pb2.FunctionCallPutDataRequest(function_call_id=function_call_id, data_chunks=data_chunks) - await retry_transient_errors(self._client.stub.FunctionCallPutDataOut, req) + await retry_transient_errors(self.client.stub.FunctionCallPutDataOut, req) async def generator_output_task(self, function_call_id: str, data_format: int, message_rx: asyncio.Queue) -> None: """Task that feeds generator outputs into a function call's `data_out` stream.""" @@ -484,7 +482,7 @@ async def _generate_inputs( # If number of active inputs is at max queue size, this will block. iteration += 1 response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors( - self._client.stub.FunctionGetInputs, request + self.client.stub.FunctionGetInputs, request ) if response.rate_limit_sleep_duration: @@ -540,7 +538,7 @@ async def run_inputs_outputs( self._semaphore = asyncio.Semaphore(input_concurrency) async for inputs in self._generate_inputs(batch_max_size, batch_linger_ms): - io_context = await IOContext.create(self, finalized_functions, inputs, batch_max_size > 0) + io_context = await IOContext.create(self.client, finalized_functions, inputs, batch_max_size > 0) self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time() yield io_context self.current_input_id, self.current_input_started_at = (None, None) @@ -550,34 +548,22 @@ async def run_inputs_outputs( await self._semaphore.acquire() @synchronizer.no_io_translation - async def format_and_push_outputs( - self, io_context: IOContext, started_at: float, data_format: int, **kwargs + async def _push_outputs( + self, io_context: IOContext, started_at: float, data_format: int, results: List[Any] ) -> None: - if "data" not in kwargs: - kwargs_list = [kwargs] * len(io_context.input_ids) - elif "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: - exc_data = self.serialize_exception(kwargs.pop("data")) - # if batched, duplicate exception to all inputs - kwargs_list = await asyncio.gather(*[self.format_blob_data(exc_data, kwargs) for _ in io_context.input_ids]) - else: - data = io_context.validate_output_data(kwargs.pop("data")) - # if batched, split the list of data to all inputs - kwargs_list = await asyncio.gather( - *[self.format_blob_data(self.serialize_data_format(d, data_format), kwargs.copy()) for d in data] - ) - + output_created_at = time.time() outputs = [ api_pb2.FunctionPutOutputsItem( input_id=input_id, input_started_at=started_at, - output_created_at=time.time(), - result=api_pb2.GenericResult(**kwargs), + output_created_at=output_created_at, + result=result, data_format=data_format, ) - for input_id, kwargs in zip(io_context.input_ids, kwargs_list) + for input_id, result in zip(io_context.input_ids, results) ] await retry_transient_errors( - self._client.stub.FunctionPutOutputs, + self.client.stub.FunctionPutOutputs, api_pb2.FunctionPutOutputsRequest(outputs=outputs), additional_status_codes=[Status.RESOURCE_EXHAUSTED], max_retries=None, # Retry indefinitely, trying every 1s. @@ -603,6 +589,19 @@ def serialize_traceback(self, exc: BaseException) -> Tuple[Optional[bytes], Opti return serialized_tb, tb_line_cache + async def format_exception_results(self, io_context: IOContext, data: Any, **kwargs) -> List[api_pb2.GenericResult]: + kwargs = await self.format_blob_data(self.serialize_exception(data), kwargs) + return [api_pb2.GenericResult(**kwargs) for _ in io_context.input_ids] + + async def format_output_results( + self, io_context: IOContext, data: Any, data_format: int, **kwargs + ) -> api_pb2.GenericResult: + data = io_context.validate_output_data(data) + kwargs_list = await asyncio.gather( + *[self.format_blob_data(self.serialize_data_format(d, data_format), kwargs.copy()) for d in data] + ) + return [api_pb2.GenericResult(**kwargs) for kwargs in kwargs_list] + @asynccontextmanager async def handle_user_exception(self) -> AsyncGenerator[None, None]: """Sets the task as failed in a way where it's not retried. @@ -632,7 +631,7 @@ async def handle_user_exception(self) -> AsyncGenerator[None, None]: ) req = api_pb2.TaskResultRequest(result=result) - await retry_transient_errors(self._client.stub.TaskResult, req) + await retry_transient_errors(self.client.stub.TaskResult, req) # Shut down the task gracefully raise UserException() @@ -680,10 +679,8 @@ async def handle_input_exception( repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000] repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception" - await self.format_and_push_outputs( + results = await self.format_exception_results( io_context=io_context, - started_at=started_at, - data_format=api_pb2.DATA_FORMAT_PICKLE, status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, data=exc, exception=repr_exc, @@ -691,6 +688,12 @@ async def handle_input_exception( serialized_tb=serialized_tb, tb_line_cache=tb_line_cache, ) + await self._push_outputs( + io_context=io_context, + started_at=started_at, + data_format=api_pb2.DATA_FORMAT_PICKLE, + results=results, + ) await self.exit_context(started_at) async def exit_context(self, started_at): @@ -700,12 +703,14 @@ async def exit_context(self, started_at): @synchronizer.no_io_translation async def push_outputs(self, io_context: IOContext, started_at: float, data: Any, data_format: int) -> None: - await self.format_and_push_outputs( + results = await self.format_output_results( + io_context, data, data_format, status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS + ) + await self._push_outputs( io_context=io_context, started_at=started_at, data_format=data_format, - data=data, - status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS, + results=results, ) await self.exit_context(started_at) @@ -761,7 +766,7 @@ async def memory_restore(self) -> None: "CUDA device availability may be inaccurate." ) - self._client = await _Client.from_env() + self.client = await _Client.from_env() async def memory_snapshot(self) -> None: """Message server indicating that function is ready to be checkpointed.""" @@ -775,11 +780,11 @@ async def memory_snapshot(self) -> None: self._waiting_for_memory_snapshot = True self._heartbeat_condition.notify_all() - await self._client.stub.ContainerCheckpoint( + await self.client.stub.ContainerCheckpoint( api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) ) - await self._client._close(forget_credentials=True) + await self.client._close(forget_credentials=True) logger.debug("Memory snapshot request sent. Connection closed.") await self.memory_restore() @@ -800,7 +805,7 @@ async def volume_commit(self, volume_ids: List[str]) -> None: results = await asyncio.gather( *[ retry_transient_errors( - self._client.stub.VolumeCommit, + self.client.stub.VolumeCommit, api_pb2.VolumeCommitRequest(volume_id=v_id), max_retries=9, base_delay=0.25, @@ -838,7 +843,7 @@ async def interact(self): # todo(nathan): add warning if concurrency limit > 1. but idk how to check this here # todo(nathan): check if function interactivity is enabled try: - await self._client.stub.FunctionStartPtyShell(Empty()) + await self.client.stub.FunctionStartPtyShell(Empty()) except Exception as e: print("Error: Failed to start PTY shell.") raise e From d459a0e535d55d772d4c57dc0bc8aa014e6b4794 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Wed, 7 Aug 2024 13:27:31 +0000 Subject: [PATCH 17/23] remove blob download --- modal/_container_io_manager.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 2c16150c0..1ddccbd1e 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -374,9 +374,6 @@ def deserialize(self, data: bytes) -> Any: def serialize_data_format(self, obj: Any, data_format: int) -> bytes: return serialize_data_format(obj, data_format) - async def blob_download(self, blob_id: str) -> bytes: - return await blob_download(blob_id, self.client.stub) - async def format_blob_data(self, data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]: if len(data) > MAX_OBJECT_SIZE_BYTES: kwargs["data_blob_id"] = await blob_upload(data, self.client.stub) From 593453e7ffe60574f9584bb9ad9f4c1e7c188fb1 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Wed, 7 Aug 2024 13:46:14 +0000 Subject: [PATCH 18/23] reword errors --- modal/_container_io_manager.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 1ddccbd1e..7f1691742 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -105,16 +105,11 @@ def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List]]: return self._deserialized_args[0] func_name = self.finalized_function.callable.__name__ - # batched function cannot be generator - if self.finalized_function.is_generator: - raise InvalidError(f"Modal batched function {func_name} cannot be a generator.") # batched function cannot have default arguments param_names = [] for param in inspect.signature(self.finalized_function.callable).parameters.values(): param_names.append(param.name) - if param.default is not inspect.Parameter.empty: - raise InvalidError(f"Modal batched function {func_name} does not accept default arguments.") # aggregate args and kwargs of all inputs into a kwarg dict kwargs_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(self.input_ids))] @@ -122,7 +117,7 @@ def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List]]: # check that all batched inputs should have the same number of args and kwargs if (num_params := len(args) + len(kwargs)) != len(param_names): raise InvalidError( - f"Modal batched function {func_name} takes {len(param_names)} positional arguments, but one call has {num_params}.." # noqa + f"Modal batched function {func_name} takes {len(param_names)} positional arguments, but one function call in the batch has {num_params}.." # noqa ) for j, arg in enumerate(args): @@ -130,17 +125,16 @@ def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List]]: for k, v in kwargs.items(): if k not in param_names: raise InvalidError( - f"Modal batched function {func_name} got unexpected keyword argument {k} in one call." + f"Modal batched function {func_name} got unexpected keyword argument {k} in one function call in the batch." # noqa ) if k in kwargs_by_inputs[i]: raise InvalidError( - f"Modal batched function {func_name} got multiple values for argument {k} in one call." + f"Modal batched function {func_name} got multiple values for argument {k} in one function call in the batch." # noqa ) kwargs_by_inputs[i][k] = v formatted_kwargs = { - param_name: [kwargs_by_inputs[i][param_name] for i in range(len(kwargs_by_inputs))] - for param_name in param_names + param_name: [kwargs[param_name] for kwargs in kwargs_by_inputs] for param_name in param_names } return (), formatted_kwargs From 3715ec4a5b56f077be81bac9a3ef5f7138fa471a Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Wed, 7 Aug 2024 17:48:19 +0000 Subject: [PATCH 19/23] revert _client --- modal/_container_io_manager.py | 46 +++++++++++++++++----------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 7f1691742..8a11487ff 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -178,8 +178,6 @@ class _ContainerIOManager: current_input_id: Optional[str] current_input_started_at: Optional[float] - client: _Client - _input_concurrency: Optional[int] _semaphore: Optional[asyncio.Semaphore] _environment_name: str @@ -190,6 +188,8 @@ class _ContainerIOManager: _is_interactivity_enabled: bool _fetching_inputs: bool + _client: _Client + _GENERATOR_STOP_SENTINEL: ClassVar[Sentinel] = Sentinel() _singleton: ClassVar[Optional["_ContainerIOManager"]] = None @@ -217,8 +217,8 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client): self._is_interactivity_enabled = False self._fetching_inputs = True - self.client = client - assert isinstance(self.client, _Client) + self._client = client + assert isinstance(self._client, _Client) def __new__(cls, container_args: api_pb2.ContainerArguments, client: _Client) -> "_ContainerIOManager": cls._singleton = super().__new__(cls) @@ -272,7 +272,7 @@ async def _heartbeat_handle_cancellations(self) -> bool: # TODO(erikbern): capture exceptions? response = await retry_transient_errors( - self.client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT + self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT ) if response.HasField("cancel_input_event"): @@ -324,7 +324,7 @@ def stop_heartbeat(self): async def get_app_objects(self) -> RunningApp: req = api_pb2.AppGetObjectsRequest(app_id=self.app_id, include_unindexed=True) - resp = await retry_transient_errors(self.client.stub.AppGetObjects, req) + resp = await retry_transient_errors(self._client.stub.AppGetObjects, req) logger.debug(f"AppGetObjects received {len(resp.items)} objects for app {self.app_id}") tag_to_object_id = {} @@ -345,7 +345,7 @@ async def get_app_objects(self) -> RunningApp: async def get_serialized_function(self) -> Tuple[Optional[Any], Callable]: # Fetch the serialized function definition request = api_pb2.FunctionGetSerializedRequest(function_id=self.function_id) - response = await self.client.stub.FunctionGetSerialized(request) + response = await self._client.stub.FunctionGetSerialized(request) if response.function_serialized: fun = self.deserialize(response.function_serialized) else: @@ -362,7 +362,7 @@ def serialize(self, obj: Any) -> bytes: return serialize(obj) def deserialize(self, data: bytes) -> Any: - return deserialize(data, self.client) + return deserialize(data, self._client) @synchronizer.no_io_translation def serialize_data_format(self, obj: Any, data_format: int) -> bytes: @@ -370,14 +370,14 @@ def serialize_data_format(self, obj: Any, data_format: int) -> bytes: async def format_blob_data(self, data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]: if len(data) > MAX_OBJECT_SIZE_BYTES: - kwargs["data_blob_id"] = await blob_upload(data, self.client.stub) + kwargs["data_blob_id"] = await blob_upload(data, self._client.stub) else: kwargs["data"] = data return kwargs async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]: """Read from the `data_in` stream of a function call.""" - async for data in _stream_function_call_data(self.client, function_call_id, "data_in"): + async for data in _stream_function_call_data(self._client, function_call_id, "data_in"): yield data async def put_data_out( @@ -397,13 +397,13 @@ async def put_data_out( for i, message_bytes in enumerate(messages_bytes): chunk = api_pb2.DataChunk(data_format=data_format, index=start_index + i) # type: ignore if len(message_bytes) > MAX_OBJECT_SIZE_BYTES: - chunk.data_blob_id = await blob_upload(message_bytes, self.client.stub) + chunk.data_blob_id = await blob_upload(message_bytes, self._client.stub) else: chunk.data = message_bytes data_chunks.append(chunk) req = api_pb2.FunctionCallPutDataRequest(function_call_id=function_call_id, data_chunks=data_chunks) - await retry_transient_errors(self.client.stub.FunctionCallPutDataOut, req) + await retry_transient_errors(self._client.stub.FunctionCallPutDataOut, req) async def generator_output_task(self, function_call_id: str, data_format: int, message_rx: asyncio.Queue) -> None: """Task that feeds generator outputs into a function call's `data_out` stream.""" @@ -473,7 +473,7 @@ async def _generate_inputs( # If number of active inputs is at max queue size, this will block. iteration += 1 response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors( - self.client.stub.FunctionGetInputs, request + self._client.stub.FunctionGetInputs, request ) if response.rate_limit_sleep_duration: @@ -529,7 +529,7 @@ async def run_inputs_outputs( self._semaphore = asyncio.Semaphore(input_concurrency) async for inputs in self._generate_inputs(batch_max_size, batch_linger_ms): - io_context = await IOContext.create(self.client, finalized_functions, inputs, batch_max_size > 0) + io_context = await IOContext.create(self._client, finalized_functions, inputs, batch_max_size > 0) self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time() yield io_context self.current_input_id, self.current_input_started_at = (None, None) @@ -540,7 +540,7 @@ async def run_inputs_outputs( @synchronizer.no_io_translation async def _push_outputs( - self, io_context: IOContext, started_at: float, data_format: int, results: List[Any] + self, io_context: IOContext, started_at: float, data_format: int, results: List[api_pb2.GenericResult] ) -> None: output_created_at = time.time() outputs = [ @@ -554,7 +554,7 @@ async def _push_outputs( for input_id, result in zip(io_context.input_ids, results) ] await retry_transient_errors( - self.client.stub.FunctionPutOutputs, + self._client.stub.FunctionPutOutputs, api_pb2.FunctionPutOutputsRequest(outputs=outputs), additional_status_codes=[Status.RESOURCE_EXHAUSTED], max_retries=None, # Retry indefinitely, trying every 1s. @@ -586,7 +586,7 @@ async def format_exception_results(self, io_context: IOContext, data: Any, **kwa async def format_output_results( self, io_context: IOContext, data: Any, data_format: int, **kwargs - ) -> api_pb2.GenericResult: + ) -> List[api_pb2.GenericResult]: data = io_context.validate_output_data(data) kwargs_list = await asyncio.gather( *[self.format_blob_data(self.serialize_data_format(d, data_format), kwargs.copy()) for d in data] @@ -622,7 +622,7 @@ async def handle_user_exception(self) -> AsyncGenerator[None, None]: ) req = api_pb2.TaskResultRequest(result=result) - await retry_transient_errors(self.client.stub.TaskResult, req) + await retry_transient_errors(self._client.stub.TaskResult, req) # Shut down the task gracefully raise UserException() @@ -757,7 +757,7 @@ async def memory_restore(self) -> None: "CUDA device availability may be inaccurate." ) - self.client = await _Client.from_env() + self._client = await _Client.from_env() async def memory_snapshot(self) -> None: """Message server indicating that function is ready to be checkpointed.""" @@ -771,11 +771,11 @@ async def memory_snapshot(self) -> None: self._waiting_for_memory_snapshot = True self._heartbeat_condition.notify_all() - await self.client.stub.ContainerCheckpoint( + await self._client.stub.ContainerCheckpoint( api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) ) - await self.client._close(forget_credentials=True) + await self._client._close(forget_credentials=True) logger.debug("Memory snapshot request sent. Connection closed.") await self.memory_restore() @@ -796,7 +796,7 @@ async def volume_commit(self, volume_ids: List[str]) -> None: results = await asyncio.gather( *[ retry_transient_errors( - self.client.stub.VolumeCommit, + self._client.stub.VolumeCommit, api_pb2.VolumeCommitRequest(volume_id=v_id), max_retries=9, base_delay=0.25, @@ -834,7 +834,7 @@ async def interact(self): # todo(nathan): add warning if concurrency limit > 1. but idk how to check this here # todo(nathan): check if function interactivity is enabled try: - await self.client.stub.FunctionStartPtyShell(Empty()) + await self._client.stub.FunctionStartPtyShell(Empty()) except Exception as e: print("Error: Failed to start PTY shell.") raise e From 3d8976fc78b117d2cf03bee93ea65bf3d6fb7fee Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Thu, 8 Aug 2024 13:37:22 +0000 Subject: [PATCH 20/23] fix type --- modal/_container_io_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 8a11487ff..ba7c41cdd 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -81,7 +81,7 @@ async def create( assert len(inputs) >= 1 if is_batched else len(inputs) == 1 input_ids, function_call_ids, inputs = zip(*inputs) - async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput): + async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -> api_pb2.FunctionInput: # If we got a pointer to a blob, download it from S3. if input.WhichOneof("args_oneof") == "args_blob_id": args = await blob_download(input.args_blob_id, client) From 9914aa68300c80983a49319fe3fc84192c6a65c8 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Thu, 8 Aug 2024 15:38:14 +0000 Subject: [PATCH 21/23] remote format results --- modal/_container_io_manager.py | 57 ++++++++++++++++------------------ 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index ba7c41cdd..6edc505fb 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -368,12 +368,12 @@ def deserialize(self, data: bytes) -> Any: def serialize_data_format(self, obj: Any, data_format: int) -> bytes: return serialize_data_format(obj, data_format) - async def format_blob_data(self, data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]: - if len(data) > MAX_OBJECT_SIZE_BYTES: - kwargs["data_blob_id"] = await blob_upload(data, self._client.stub) - else: - kwargs["data"] = data - return kwargs + async def format_blob_data(self, data: bytes) -> Dict[str, Any]: + return ( + {"data_blob_id": await blob_upload(data, self._client.stub)} + if len(data) > MAX_OBJECT_SIZE_BYTES + else {"data": data} + ) async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]: """Read from the `data_in` stream of a function call.""" @@ -580,19 +580,6 @@ def serialize_traceback(self, exc: BaseException) -> Tuple[Optional[bytes], Opti return serialized_tb, tb_line_cache - async def format_exception_results(self, io_context: IOContext, data: Any, **kwargs) -> List[api_pb2.GenericResult]: - kwargs = await self.format_blob_data(self.serialize_exception(data), kwargs) - return [api_pb2.GenericResult(**kwargs) for _ in io_context.input_ids] - - async def format_output_results( - self, io_context: IOContext, data: Any, data_format: int, **kwargs - ) -> List[api_pb2.GenericResult]: - data = io_context.validate_output_data(data) - kwargs_list = await asyncio.gather( - *[self.format_blob_data(self.serialize_data_format(d, data_format), kwargs.copy()) for d in data] - ) - return [api_pb2.GenericResult(**kwargs) for kwargs in kwargs_list] - @asynccontextmanager async def handle_user_exception(self) -> AsyncGenerator[None, None]: """Sets the task as failed in a way where it's not retried. @@ -670,15 +657,17 @@ async def handle_input_exception( repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000] repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception" - results = await self.format_exception_results( - io_context=io_context, - status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, - data=exc, - exception=repr_exc, - traceback=traceback.format_exc(), - serialized_tb=serialized_tb, - tb_line_cache=tb_line_cache, - ) + results = [ + api_pb2.GenericResult( + status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE, + exception=repr_exc, + traceback=traceback.format_exc(), + serialized_tb=serialized_tb, + tb_line_cache=tb_line_cache, + **await self.format_blob_data(self.serialize_exception(exc)), + ) + for _ in io_context.input_ids + ] await self._push_outputs( io_context=io_context, started_at=started_at, @@ -694,9 +683,17 @@ async def exit_context(self, started_at): @synchronizer.no_io_translation async def push_outputs(self, io_context: IOContext, started_at: float, data: Any, data_format: int) -> None: - results = await self.format_output_results( - io_context, data, data_format, status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS + data = io_context.validate_output_data(data) + formatted_data = await asyncio.gather( + *[self.format_blob_data(self.serialize_data_format(d, data_format)) for d in data] ) + results = [ + api_pb2.GenericResult( + status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS, + **d, + ) + for d in formatted_data + ] await self._push_outputs( io_context=io_context, started_at=started_at, From 637dca6dcebfea4e5d700e95c06c21d2677ab4cc Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Thu, 8 Aug 2024 16:05:27 +0000 Subject: [PATCH 22/23] add docstring --- modal/_container_io_manager.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 6edc505fb..a2b135a7b 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -52,6 +52,10 @@ class FinalizedFunction: class IOContext: + """Context object for managing input, function calls, and function executions + in a batched or single input context. + """ + input_ids: List[str] function_call_ids: List[str] finalized_function: FinalizedFunction @@ -146,16 +150,16 @@ def call_finalized_function(self) -> Any: return res def validate_output_data(self, data: Any) -> List[Any]: - if self._is_batched: - function_name = self.finalized_function.callable.__name__ - if not isinstance(data, list): - raise InvalidError(f"Output of batch function {function_name} must be a list.") - if len(data) != len(self.input_ids): - raise InvalidError( - f"Output of batch function {function_name} must be a list of the same length as its inputs." - ) - else: - data = [data] + if not self._is_batched: + return [data] + + function_name = self.finalized_function.callable.__name__ + if not isinstance(data, list): + raise InvalidError(f"Output of batch function {function_name} must be a list.") + if len(data) != len(self.input_ids): + raise InvalidError( + f"Output of batch function {function_name} must be a list of the same length as its inputs." + ) return data From 1b16994643f0c6784b974546aa199a20d23c8e93 Mon Sep 17 00:00:00 2001 From: cathyzbn Date: Thu, 8 Aug 2024 18:36:23 +0000 Subject: [PATCH 23/23] error phrasing --- modal/_container_io_manager.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index a2b135a7b..4444964c5 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -121,7 +121,7 @@ def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List]]: # check that all batched inputs should have the same number of args and kwargs if (num_params := len(args) + len(kwargs)) != len(param_names): raise InvalidError( - f"Modal batched function {func_name} takes {len(param_names)} positional arguments, but one function call in the batch has {num_params}.." # noqa + f"Modal batched function {func_name} takes {len(param_names)} positional arguments, but one invocation in the batch has {num_params}." # noqa ) for j, arg in enumerate(args): @@ -129,11 +129,11 @@ def _args_and_kwargs(self) -> Tuple[Tuple[Any, ...], Dict[str, List]]: for k, v in kwargs.items(): if k not in param_names: raise InvalidError( - f"Modal batched function {func_name} got unexpected keyword argument {k} in one function call in the batch." # noqa + f"Modal batched function {func_name} got unexpected keyword argument {k} in one invocation in the batch." # noqa ) if k in kwargs_by_inputs[i]: raise InvalidError( - f"Modal batched function {func_name} got multiple values for argument {k} in one function call in the batch." # noqa + f"Modal batched function {func_name} got multiple values for argument {k} in one invocation in the batch." # noqa ) kwargs_by_inputs[i][k] = v @@ -155,10 +155,10 @@ def validate_output_data(self, data: Any) -> List[Any]: function_name = self.finalized_function.callable.__name__ if not isinstance(data, list): - raise InvalidError(f"Output of batch function {function_name} must be a list.") + raise InvalidError(f"Output of batched function {function_name} must be a list.") if len(data) != len(self.input_ids): raise InvalidError( - f"Output of batch function {function_name} must be a list of the same length as its inputs." + f"Output of batched function {function_name} must be a list of equal length as its inputs." ) return data