Skip to content

Commit

Permalink
cleanup IOContest init
Browse files Browse the repository at this point in the history
  • Loading branch information
cathyzbn committed Aug 5, 2024
1 parent 69d701f commit cdb4ef9
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@ class IOContext:
function_call_ids: List[str]
finalized_function: FinalizedFunction

def __init__(
self,
input_ids: List[str],
function_call_ids: List[str],
finalized_function: FinalizedFunction,
deserialized_args: List,
is_batched: bool,
):
self.input_ids = input_ids
self.function_call_ids = function_call_ids
self.finalized_function = finalized_function
self.deserialized_args = deserialized_args
self.is_batched = is_batched

@classmethod
async def create(
cls,
Expand All @@ -64,31 +78,28 @@ async def create(
inputs: List[Tuple[str, str, api_pb2.FunctionInput]],
is_batched: bool,
) -> "IOContext":
self = cls.__new__(cls)
assert len(inputs) >= 1 if is_batched else len(inputs) == 1
self.input_ids, self.function_call_ids, inputs = zip(*inputs)
self.inputs = await asyncio.gather(
*[self._populate_input_blobs(container_io_manager, input) for input in inputs]
)
input_ids, function_call_ids, inputs = zip(*inputs)

async def _populate_input_blobs(container_io_manager: "_ContainerIOManager", input: api_pb2.FunctionInput):
# If we got a pointer to a blob, download it from S3.
if input.WhichOneof("args_oneof") == "args_blob_id":
args = await container_io_manager.blob_download(input.args_blob_id)
# Mutating
input.ClearField("args_blob_id")
input.args = args

return input

inputs = await asyncio.gather(*[_populate_input_blobs(container_io_manager, input) for input in inputs])
# check every input in batch executes the same function
method_name = self.inputs[0].method_name
assert all(method_name == input.method_name for input in self.inputs)
self.finalized_function = finalized_functions[method_name]
self.deserialized_args = [
method_name = inputs[0].method_name
assert all(method_name == input.method_name for input in inputs)
finalized_function = finalized_functions[method_name]
deserialized_args = [
container_io_manager.deserialize(input.args) if input.args else ((), {}) for input in inputs
]
self.is_batched = is_batched
return self

async def _populate_input_blobs(self, container_io_manager: "_ContainerIOManager", input: api_pb2.FunctionInput):
# If we got a pointer to a blob, download it from S3.
if input.WhichOneof("args_oneof") == "args_blob_id":
args = await container_io_manager.blob_download(input.args_blob_id)
# Mutating
input.ClearField("args_blob_id")
input.args = args

return input
return cls(input_ids, function_call_ids, finalized_function, deserialized_args, is_batched)

def _args_and_kwargs(self):
if not self.is_batched:
Expand Down

0 comments on commit cdb4ef9

Please sign in to comment.