From 69d701f810181112645e70e83756b29645843acb Mon Sep 17 00:00:00 2001
From: cathyzbn
Date: Mon, 5 Aug 2024 14:12:05 +0000
Subject: [PATCH] fix type check

---
 modal/_container_entrypoint.py |  3 +-
 modal/_container_io_manager.py | 82 +++++++++++++++++-----------------
 2 files changed, 43 insertions(+), 42 deletions(-)

diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py
index c390418cc..84473f79d 100644
--- a/modal/_container_entrypoint.py
+++ b/modal/_container_entrypoint.py
@@ -18,7 +18,6 @@
 from google.protobuf.message import Message
 from synchronicity import Interface
 
-from modal._container_io_manager import FinalizedFunction, IOContext
 from modal_proto import api_pb2
 
 from ._asgi import (
@@ -29,7 +28,7 @@
     webhook_asgi_app,
     wsgi_app_wrapper,
 )
-from ._container_io_manager import ContainerIOManager, UserException, _ContainerIOManager
+from ._container_io_manager import ContainerIOManager, FinalizedFunction, IOContext, UserException, _ContainerIOManager
 from ._proxy_tunnel import proxy_tunnel
 from ._serialization import deserialize, deserialize_proto_params
 from ._utils.async_utils import TaskContext, synchronizer
diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py
index 54a5ab37c..051157df1 100644
--- a/modal/_container_io_manager.py
+++ b/modal/_container_io_manager.py
@@ -19,7 +19,7 @@
 
 from modal_proto import api_pb2
 
-from ._serialization import deserialize, deserialize_data_format, serialize, serialize_data_format
+from ._serialization import deserialize, serialize, serialize_data_format
 from ._traceback import extract_traceback
 from ._utils.async_utils import TaskContext, asyncify, synchronize_api, synchronizer
 from ._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
@@ -56,42 +56,42 @@ class IOContext:
     function_call_ids: List[str]
     finalized_function: FinalizedFunction
 
+    @classmethod
     async def create(
+        cls,
         container_io_manager: "_ContainerIOManager",
         finalized_functions: Dict[str, FinalizedFunction],
         inputs: List[Tuple[str, str, api_pb2.FunctionInput]],
         is_batched: bool,
-    ) -> None:
-        self = IOContext()
-        assert len(inputs) > 0
-        self.input_ids, self.function_call_ids, self.inputs = zip(*inputs)
-        self.is_batched = is_batched
-
-        self.inputs = await asyncio.gather(*[self.populate_input_blobs(input) for input in self.inputs])
+    ) -> "IOContext":
+        self = cls.__new__(cls)
+        assert len(inputs) >= 1 if is_batched else len(inputs) == 1
+        self.input_ids, self.function_call_ids, inputs = zip(*inputs)
+        self.inputs = await asyncio.gather(
+            *[self._populate_input_blobs(container_io_manager, input) for input in inputs]
+        )
         # check every input in batch executes the same function
         method_name = self.inputs[0].method_name
         assert all(method_name == input.method_name for input in self.inputs)
         self.finalized_function = finalized_functions[method_name]
         self.deserialized_args = [
-            container_io_manager.deserialize(input.args) if input.args else ((), {}) for input in self.inputs
+            container_io_manager.deserialize(input.args) if input.args else ((), {}) for input in inputs
         ]
+        self.is_batched = is_batched
         return self
 
-    async def populate_input_blobs(self, input: api_pb2.FunctionInput):
+    async def _populate_input_blobs(self, container_io_manager: "_ContainerIOManager", input: api_pb2.FunctionInput):
         # If we got a pointer to a blob, download it from S3.
         if input.WhichOneof("args_oneof") == "args_blob_id":
-            args = await blob_download(input.args_blob_id, self._client.stub)
-
+            args = await container_io_manager.blob_download(input.args_blob_id)
             # Mutating
             input.ClearField("args_blob_id")
             input.args = args
 
-            return input
         return input
 
     def _args_and_kwargs(self):
         if not self.is_batched:
-            assert len(self.inputs) == 1
             return self.deserialized_args[0]
 
         func_name = self.finalized_function.callable.__name__
@@ -141,38 +141,27 @@ def call_finalized_function(self) -> Any:
         logger.debug(f"Finished input {self.input_ids} (async)")
         return res
 
-    def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
-        return serialize_data_format(obj, data_format)
-
-    def deserialize_data_format(self, data: bytes, data_format: int) -> Any:
-        return deserialize_data_format(data, data_format, self._client)
-
-    async def _format_data(self, data: bytes, kwargs: Dict[str, Any], blob_func: Callable) -> Dict[str, Any]:
-        if len(data) > MAX_OBJECT_SIZE_BYTES:
-            kwargs["data_blob_id"] = await blob_func(data)
-        else:
-            kwargs["data"] = data
-        return kwargs
-
     @synchronizer.no_io_translation
-    async def format_output(
-        self, started_at: float, data_format: int, blob_func: Callable, **kwargs
+    async def format_outputs(
+        self, container_io_manager: "_ContainerIOManager", started_at: float, data_format: int, **kwargs
     ) -> List[api_pb2.FunctionPutOutputsItem]:
         if "data" not in kwargs:
             kwargs_list = [kwargs] * len(self.input_ids)
 
         # data is not batched, return a single kwargs.
-        elif not self.is_batched and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
-            data = self.serialize_data_format(kwargs.pop("data"), data_format)
-            kwargs_list = [await self._format_data(data, kwargs, blob_func)]
-        elif not self.is_batched:  # data is not batched and is an exception
-            kwargs_list = [await self._format_data(kwargs.pop("data"), kwargs, blob_func)]
+        elif not self.is_batched:
+            data = (
+                serialize_data_format(kwargs.pop("data"), data_format)
+                if kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
+                else kwargs.pop("data")
+            )
+            kwargs_list = [await container_io_manager.format_blob_data(data, kwargs)]
         # data is batched, return a list of kwargs
         # split the list of data in kwargs to respective input_ids and report error for every input_id in batch call.
         elif "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE:
             error_data = kwargs.pop("data")
             kwargs_list = await asyncio.gather(
-                *[self._format_data(error_data, kwargs, blob_func) for _ in self.input_ids]
+                *[container_io_manager.format_blob_data(error_data, kwargs) for _ in self.input_ids]
             )
         else:
             function_name = self.finalized_function.callable.__name__
@@ -184,7 +173,10 @@ async def format_output(
                 f"Output of batch function {function_name} must be a list of the same length as its inputs."
             )
             kwargs_list = await asyncio.gather(
-                *[self._format_data(self.serialize_data_format(d, data_format), kwargs.copy(), blob_func) for d in data]
+                *[
+                    container_io_manager.format_blob_data(serialize_data_format(d, data_format), kwargs.copy())
+                    for d in data
+                ]
             )
 
         return [
@@ -407,6 +399,16 @@ def deserialize(self, data: bytes) -> Any:
     async def blob_upload(self, data: bytes) -> str:
         return await blob_upload(data, self._client.stub)
 
+    async def blob_download(self, blob_id: str) -> bytes:
+        return await blob_download(blob_id, self._client.stub)
+
+    async def format_blob_data(self, data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+        if len(data) > MAX_OBJECT_SIZE_BYTES:
+            kwargs["data_blob_id"] = await self.blob_upload(data)
+        else:
+            kwargs["data"] = data
+        return kwargs
+
     async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
         """Read from the `data_in` stream of a function call."""
         async for data in _stream_function_call_data(self._client, function_call_id, "data_in"):
@@ -671,10 +673,10 @@ async def handle_input_exception(
             repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
             repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"
 
-        outputs = await io_context.format_output(
+        outputs = await io_context.format_outputs(
+            container_io_manager=self,
             started_at=started_at,
             data_format=api_pb2.DATA_FORMAT_PICKLE,
-            blob_func=self.blob_upload,
             status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
             data=self.serialize_exception(exc),
             exception=repr_exc,
@@ -692,10 +694,10 @@ async def complete_call(self, started_at):
 
     @synchronizer.no_io_translation
     async def push_output(self, io_context: IOContext, started_at: float, data: Any, data_format: int) -> None:
-        outputs = await io_context.format_output(
+        outputs = await io_context.format_outputs(
+            container_io_manager=self,
             started_at=started_at,
             data_format=data_format,
-            blob_func=self.blob_upload,
             data=data,
             status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
         )
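
A note on the pattern above: the new IOContext.create() is an async factory. Since __init__ cannot be awaited, the classmethod allocates the instance with cls.__new__(cls), awaits the blob downloads, and only then assigns attributes. Below is a minimal runnable sketch of the same pattern; Ctx, the sleep, and the input ids are illustrative stand-ins, not code from this patch.

    import asyncio
    from typing import List

    class Ctx:
        input_ids: List[str]

        @classmethod
        async def create(cls, input_ids: List[str]) -> "Ctx":
            # Bypass __init__ so async setup can run before attributes are
            # assigned, mirroring the cls.__new__(cls) call in the patch.
            self = cls.__new__(cls)
            await asyncio.sleep(0)  # stand-in for awaiting _populate_input_blobs()
            self.input_ids = input_ids
            return self

    print(asyncio.run(Ctx.create(["in-1", "in-2"])).input_ids)  # ['in-1', 'in-2']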
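The added format_blob_data() centralizes the inline-vs-blob decision that _format_data() previously made through an injected blob_func: payloads larger than MAX_OBJECT_SIZE_BYTES are uploaded and referenced by data_blob_id, while smaller ones are sent inline as data. A self-contained sketch of that contract follows; the 2 MiB threshold and fake_blob_upload() are assumptions standing in for the real constant and the client stub.

    import asyncio
    from typing import Any, Dict

    MAX_OBJECT_SIZE_BYTES = 2 * 1024 * 1024  # assumed value for this sketch

    async def fake_blob_upload(data: bytes) -> str:
        # Stand-in for the real blob_upload(data, self._client.stub).
        return "bl-123"

    async def format_blob_data(data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        # Large payloads go to blob storage and travel by id; small ones inline.
        if len(data) > MAX_OBJECT_SIZE_BYTES:
            kwargs["data_blob_id"] = await fake_blob_upload(data)
        else:
            kwargs["data"] = data
        return kwargs

    print(asyncio.run(format_blob_data(b"small", {})))          # {'data': b'small'}
    print(asyncio.run(format_blob_data(b"x" * (3 << 20), {})))  # {'data_blob_id': 'bl-123'}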