# Patch summary (explanatory preamble; everything from the "diff" header line
# downward is the patch itself, restored to one diff record per line).
#
# Target file: modal/_container_io_manager.py, class IOContext.
#
# Hunk 1 adds an explicit IOContext.__init__ that stores five attributes:
#   input_ids, function_call_ids, finalized_function, deserialized_args,
#   is_batched.
#
# Hunk 2 rewrites the IOContext.create classmethod:
#   * The instance is no longer built with cls.__new__(cls) followed by
#     attribute assignment; the values are computed as locals and passed to
#     cls(...) in a single return statement.
#   * _populate_input_blobs is removed as a method of the class and
#     re-declared as an async function nested inside create (no self
#     parameter). Its body is otherwise unchanged: when the input's
#     "args_oneof" is "args_blob_id" it awaits
#     container_io_manager.blob_download, clears args_blob_id and assigns the
#     downloaded bytes to input.args (the message is mutated in place), then
#     returns the input.
#   * The result of asyncio.gather is kept in the local name inputs instead
#     of being assigned to self.inputs.
#
# NOTE(review): after this patch nothing assigns an `inputs` attribute on the
#   instance -- confirm no code outside these hunks reads IOContext.inputs.
# NOTE(review): input_ids and function_call_ids come from zip(*inputs), which
#   yields tuples, while __init__ annotates them as List[str] -- confirm that
#   downstream users accept tuples.
# NOTE(review): the two parameters of create that sit between the hunks
#   (old lines 62-63) are not visible here; the body uses the names
#   container_io_manager and finalized_functions -- presumably those are the
#   hidden parameters; verify against the full file.
diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py
index 051157df1..4007b8cde 100644
--- a/modal/_container_io_manager.py
+++ b/modal/_container_io_manager.py
@@ -56,6 +56,20 @@ class IOContext:
     function_call_ids: List[str]
     finalized_function: FinalizedFunction
 
+    def __init__(
+        self,
+        input_ids: List[str],
+        function_call_ids: List[str],
+        finalized_function: FinalizedFunction,
+        deserialized_args: List,
+        is_batched: bool,
+    ):
+        self.input_ids = input_ids
+        self.function_call_ids = function_call_ids
+        self.finalized_function = finalized_function
+        self.deserialized_args = deserialized_args
+        self.is_batched = is_batched
+
     @classmethod
     async def create(
         cls,
@@ -64,31 +78,28 @@ async def create(
         inputs: List[Tuple[str, str, api_pb2.FunctionInput]],
         is_batched: bool,
     ) -> "IOContext":
-        self = cls.__new__(cls)
         assert len(inputs) >= 1 if is_batched else len(inputs) == 1
-        self.input_ids, self.function_call_ids, inputs = zip(*inputs)
-        self.inputs = await asyncio.gather(
-            *[self._populate_input_blobs(container_io_manager, input) for input in inputs]
-        )
+        input_ids, function_call_ids, inputs = zip(*inputs)
+
+        async def _populate_input_blobs(container_io_manager: "_ContainerIOManager", input: api_pb2.FunctionInput):
+            # If we got a pointer to a blob, download it from S3.
+            if input.WhichOneof("args_oneof") == "args_blob_id":
+                args = await container_io_manager.blob_download(input.args_blob_id)
+                # Mutating
+                input.ClearField("args_blob_id")
+                input.args = args
+
+            return input
+
+        inputs = await asyncio.gather(*[_populate_input_blobs(container_io_manager, input) for input in inputs])
         # check every input in batch executes the same function
-        method_name = self.inputs[0].method_name
-        assert all(method_name == input.method_name for input in self.inputs)
-        self.finalized_function = finalized_functions[method_name]
-        self.deserialized_args = [
+        method_name = inputs[0].method_name
+        assert all(method_name == input.method_name for input in inputs)
+        finalized_function = finalized_functions[method_name]
+        deserialized_args = [
             container_io_manager.deserialize(input.args) if input.args else ((), {}) for input in inputs
         ]
-        self.is_batched = is_batched
-        return self
-
-    async def _populate_input_blobs(self, container_io_manager: "_ContainerIOManager", input: api_pb2.FunctionInput):
-        # If we got a pointer to a blob, download it from S3.
-        if input.WhichOneof("args_oneof") == "args_blob_id":
-            args = await container_io_manager.blob_download(input.args_blob_id)
-            # Mutating
-            input.ClearField("args_blob_id")
-            input.args = args
-
-        return input
+        return cls(input_ids, function_call_ids, finalized_function, deserialized_args, is_batched)
 
     def _args_and_kwargs(self):
         if not self.is_batched: