
Commit

fix type check
cathyzbn committed Aug 5, 2024
1 parent 3b8c11f commit 69d701f
Showing 2 changed files with 43 additions and 42 deletions.
3 changes: 1 addition & 2 deletions modal/_container_entrypoint.py
@@ -18,7 +18,6 @@
 from google.protobuf.message import Message
 from synchronicity import Interface
 
-from modal._container_io_manager import FinalizedFunction, IOContext
 from modal_proto import api_pb2
 
 from ._asgi import (
@@ -29,7 +28,7 @@
     webhook_asgi_app,
     wsgi_app_wrapper,
 )
-from ._container_io_manager import ContainerIOManager, UserException, _ContainerIOManager
+from ._container_io_manager import ContainerIOManager, FinalizedFunction, IOContext, UserException, _ContainerIOManager
 from ._proxy_tunnel import proxy_tunnel
 from ._serialization import deserialize, deserialize_proto_params
 from ._utils.async_utils import TaskContext, synchronizer
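The change above replaces a pair of imports of the same module, one absolute and one relative, with a single relative import. A plausible reading of the commit title is that the duplicated spellings are what tripped the type checker, since some checkers resolve modal._container_io_manager and ._container_io_manager to distinct module objects; the sketch below restates the before and after purely for illustration:

# Before: the same module imported under two names (absolute and relative).
# Assumption: a type checker may treat the two spellings as distinct modules,
# so classes imported via one do not unify with the other.
# from modal._container_io_manager import FinalizedFunction, IOContext
# from ._container_io_manager import ContainerIOManager, UserException, _ContainerIOManager

# After: one relative import providing every name the entrypoint needs.
from ._container_io_manager import (
    ContainerIOManager,
    FinalizedFunction,
    IOContext,
    UserException,
    _ContainerIOManager,
)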
82 changes: 42 additions & 40 deletions modal/_container_io_manager.py
@@ -19,7 +19,7 @@
 
 from modal_proto import api_pb2
 
-from ._serialization import deserialize, deserialize_data_format, serialize, serialize_data_format
+from ._serialization import deserialize, serialize, serialize_data_format
 from ._traceback import extract_traceback
 from ._utils.async_utils import TaskContext, asyncify, synchronize_api, synchronizer
 from ._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
@@ -56,42 +56,42 @@ class IOContext:
     function_call_ids: List[str]
     finalized_function: FinalizedFunction
 
     @classmethod
     async def create(
         cls,
         container_io_manager: "_ContainerIOManager",
         finalized_functions: Dict[str, FinalizedFunction],
         inputs: List[Tuple[str, str, api_pb2.FunctionInput]],
         is_batched: bool,
-    ) -> None:
-        self = IOContext()
-        assert len(inputs) > 0
-        self.input_ids, self.function_call_ids, self.inputs = zip(*inputs)
-        self.is_batched = is_batched
-
-        self.inputs = await asyncio.gather(*[self.populate_input_blobs(input) for input in self.inputs])
+    ) -> "IOContext":
+        self = cls.__new__(cls)
+        assert len(inputs) >= 1 if is_batched else len(inputs) == 1
+        self.input_ids, self.function_call_ids, inputs = zip(*inputs)
+        self.inputs = await asyncio.gather(
+            *[self._populate_input_blobs(container_io_manager, input) for input in inputs]
+        )
         # check every input in batch executes the same function
         method_name = self.inputs[0].method_name
         assert all(method_name == input.method_name for input in self.inputs)
         self.finalized_function = finalized_functions[method_name]
         self.deserialized_args = [
-            container_io_manager.deserialize(input.args) if input.args else ((), {}) for input in self.inputs
+            container_io_manager.deserialize(input.args) if input.args else ((), {}) for input in inputs
         ]
+        self.is_batched = is_batched
+        return self
 
-    async def populate_input_blobs(self, input: api_pb2.FunctionInput):
+    async def _populate_input_blobs(self, container_io_manager: "_ContainerIOManager", input: api_pb2.FunctionInput):
         # If we got a pointer to a blob, download it from S3.
         if input.WhichOneof("args_oneof") == "args_blob_id":
-            args = await blob_download(input.args_blob_id, self._client.stub)
-
+            args = await container_io_manager.blob_download(input.args_blob_id)
             # Mutating
             input.ClearField("args_blob_id")
             input.args = args
-            return input
 
         return input
 
     def _args_and_kwargs(self):
         if not self.is_batched:
             assert len(self.inputs) == 1
             return self.deserialized_args[0]
 
         func_name = self.finalized_function.callable.__name__
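The rewritten IOContext.create is an async factory: __init__ cannot await, so the classmethod allocates the instance with cls.__new__(cls), awaits the setup work (here, populating blob arguments), and returns the finished object. Below is a minimal, self-contained sketch of that pattern; the Record class and _fetch helper are illustrative names, not part of Modal's API:

import asyncio
from typing import List


class Record:
    # Sketch of the async-factory pattern used by IOContext.create above.
    # __init__ cannot await, so a classmethod allocates the instance with
    # cls.__new__(cls), does the async work, and returns the instance.
    values: List[int]

    @classmethod
    async def create(cls, ids: List[int]) -> "Record":
        self = cls.__new__(cls)  # allocate without running __init__
        # Concurrently perform the async setup for every id.
        self.values = list(await asyncio.gather(*[cls._fetch(i) for i in ids]))
        return self

    @staticmethod
    async def _fetch(i: int) -> int:
        await asyncio.sleep(0)  # stand-in for a real network call
        return i * 2


async def main() -> None:
    record = await Record.create([1, 2, 3])
    print(record.values)  # [2, 4, 6]


asyncio.run(main())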
@@ -141,38 +141,27 @@ def call_finalized_function(self) -> Any:
         logger.debug(f"Finished input {self.input_ids} (async)")
         return res
 
-    def serialize_data_format(self, obj: Any, data_format: int) -> bytes:
-        return serialize_data_format(obj, data_format)
-
-    def deserialize_data_format(self, data: bytes, data_format: int) -> Any:
-        return deserialize_data_format(data, data_format, self._client)
-
-    async def _format_data(self, data: bytes, kwargs: Dict[str, Any], blob_func: Callable) -> Dict[str, Any]:
-        if len(data) > MAX_OBJECT_SIZE_BYTES:
-            kwargs["data_blob_id"] = await blob_func(data)
-        else:
-            kwargs["data"] = data
-        return kwargs
-
     @synchronizer.no_io_translation
-    async def format_output(
-        self, started_at: float, data_format: int, blob_func: Callable, **kwargs
+    async def format_outputs(
+        self, container_io_manager: "_ContainerIOManager", started_at: float, data_format: int, **kwargs
     ) -> List[api_pb2.FunctionPutOutputsItem]:
         if "data" not in kwargs:
             kwargs_list = [kwargs] * len(self.input_ids)
         # data is not batched, return a single kwargs.
-        elif not self.is_batched and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
-            data = self.serialize_data_format(kwargs.pop("data"), data_format)
-            kwargs_list = [await self._format_data(data, kwargs, blob_func)]
-        elif not self.is_batched:  # data is not batched and is an exception
-            kwargs_list = [await self._format_data(kwargs.pop("data"), kwargs, blob_func)]
+        elif not self.is_batched:
+            data = (
+                serialize_data_format(kwargs.pop("data"), data_format)
+                if kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
+                else kwargs.pop("data")
+            )
+            kwargs_list = [await container_io_manager.format_blob_data(data, kwargs)]
 
         # data is batched, return a list of kwargs
         # split the list of data in kwargs to respective input_ids and report error for every input_id in batch call.
         elif "status" in kwargs and kwargs["status"] == api_pb2.GenericResult.GENERIC_STATUS_FAILURE:
             error_data = kwargs.pop("data")
             kwargs_list = await asyncio.gather(
-                *[self._format_data(error_data, kwargs, blob_func) for _ in self.input_ids]
+                *[container_io_manager.format_blob_data(error_data, kwargs) for _ in self.input_ids]
             )
         else:
             function_name = self.finalized_function.callable.__name__
@@ -184,7 +173,10 @@ async def format_output(
                     f"Output of batch function {function_name} must be a list of the same length as its inputs."
                 )
             kwargs_list = await asyncio.gather(
-                *[self._format_data(self.serialize_data_format(d, data_format), kwargs.copy(), blob_func) for d in data]
+                *[
+                    container_io_manager.format_blob_data(serialize_data_format(d, data_format), kwargs.copy())
+                    for d in data
+                ]
             )
 
         return [
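Viewed together, the branches of format_outputs produce three shapes of result: a single item for an unbatched call, one copy of the same error for every input of a failed batch, and a one-to-one mapping from list elements to inputs for a successful batch. A simplified sketch of that fan-out, with serialization, blob offloading, and the protobuf types omitted (the dict shape is illustrative):

from typing import Any, Dict, List


def fan_out(data: Any, input_ids: List[str], is_batched: bool, failed: bool) -> List[Dict[str, Any]]:
    # Unbatched: exactly one output item.
    if not is_batched:
        return [{"input_id": input_ids[0], "data": data}]
    # Failed batch: report the same error for every input in the batch.
    if failed:
        return [{"input_id": i, "data": data} for i in input_ids]
    # Successful batch: output must be a list matching the inputs one-to-one.
    if not isinstance(data, list) or len(data) != len(input_ids):
        raise ValueError("batched output must be a list of the same length as its inputs")
    return [{"input_id": i, "data": d} for i, d in zip(input_ids, data)]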
@@ -407,6 +399,16 @@ def deserialize(self, data: bytes) -> Any:
     async def blob_upload(self, data: bytes) -> str:
         return await blob_upload(data, self._client.stub)
 
+    async def blob_download(self, blob_id: str) -> bytes:
+        return await blob_download(blob_id, self._client.stub)
+
+    async def format_blob_data(self, data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+        if len(data) > MAX_OBJECT_SIZE_BYTES:
+            kwargs["data_blob_id"] = await self.blob_upload(data)
+        else:
+            kwargs["data"] = data
+        return kwargs
+
     async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
         """Read from the `data_in` stream of a function call."""
         async for data in _stream_function_call_data(self._client, function_call_id, "data_in"):
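The new format_blob_data method centralizes the size cutoff that previously lived in IOContext._format_data: payloads larger than MAX_OBJECT_SIZE_BYTES are uploaded to blob storage and referenced by data_blob_id, while smaller ones are inlined as data. A standalone sketch of the same decision follows; the threshold value and the upload stub are assumptions for illustration:

import asyncio
from typing import Any, Dict

MAX_OBJECT_SIZE_BYTES = 2 * 1024 * 1024  # illustrative cutoff, not necessarily Modal's


async def blob_upload(data: bytes) -> str:
    # Stand-in for the real upload RPC; returns a blob id.
    return f"blob-{len(data)}"


async def format_blob_data(data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Large payloads travel by reference, small ones inline.
    if len(data) > MAX_OBJECT_SIZE_BYTES:
        kwargs["data_blob_id"] = await blob_upload(data)
    else:
        kwargs["data"] = data
    return kwargs


async def main() -> None:
    small = await format_blob_data(b"ok", {})
    big = await format_blob_data(b"x" * (MAX_OBJECT_SIZE_BYTES + 1), {})
    print(small)                   # {'data': b'ok'}
    print("data_blob_id" in big)   # True


asyncio.run(main())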
@@ -671,10 +673,10 @@ async def handle_input_exception(
             repr_exc = repr_exc[: MAX_OBJECT_SIZE_BYTES - 1000]
             repr_exc = f"{repr_exc}...\nTrimmed {trimmed_bytes} bytes from original exception"
 
-        outputs = await io_context.format_output(
+        outputs = await io_context.format_outputs(
+            container_io_manager=self,
             started_at=started_at,
             data_format=api_pb2.DATA_FORMAT_PICKLE,
-            blob_func=self.blob_upload,
             status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE,
             data=self.serialize_exception(exc),
             exception=repr_exc,
@@ -692,10 +694,10 @@ async def complete_call(self, started_at):
 
     @synchronizer.no_io_translation
     async def push_output(self, io_context: IOContext, started_at: float, data: Any, data_format: int) -> None:
-        outputs = await io_context.format_output(
+        outputs = await io_context.format_outputs(
+            container_io_manager=self,
             started_at=started_at,
             data_format=data_format,
-            blob_func=self.blob_upload,
             data=data,
             status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS,
         )
