diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 626ab5432..d3c8e34f9 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -47,23 +47,8 @@ class LocalInput: input_id: str function_call_id: str method_name: str - args: Tuple[Any, ...] - kwargs: Dict[str, Any] - - def __init__( - self, - container_io_manager: "_ContainerIOManager", - input_id: str, - function_call_id: str, - input_pb: api_pb2.FunctionInput, - ): - self.input_id = input_id - self.function_call_id = function_call_id - self.method_name = input_pb.method_name - self.args, self.kwargs = container_io_manager.deserialize(input_pb.args) if input_pb.args else ((), {}) - - container_io_manager.current_input_id = input_id - container_io_manager.current_input_started_at = time.time() + args: Any + kwargs: Any class _ContainerIOManager: @@ -457,7 +442,9 @@ async def run_inputs_outputs( async for inputs_list in self._generate_inputs(): local_inputs_list = [] for input_id, function_call_id, input_pb in inputs_list: - local_inputs_list.append(LocalInput(self, input_id, function_call_id, input_pb)) + args, kwargs = self.deserialize(input_pb.args) if input_pb.args else ((), {}) + self.current_input_id, self.current_input_started_at = (input_id, time.time()) + local_inputs_list.append(LocalInput(input_id, function_call_id, input_pb.method_name, args, kwargs)) yield local_inputs_list self.current_input_id, self.current_input_started_at = (None, None)