Skip to content

Commit

Permalink
Process Batched inputs and outputs on client side (#2064)
Browse files Browse the repository at this point in the history
* process batched inputs on client side

* upload blob data without blocking

* reduce duplicate code

* local input change

* refactor

* cleanup IOContext init

* fix synchronize

* semantics change

* function name change

* cleanup push outputs

* remove blob download

* reword errors

* revert _client

* remote format results

* add docstring

* error phrasing
  • Loading branch information
cathyzbn committed Aug 8, 2024
1 parent a6887e4 commit fbe6b5b
Show file tree
Hide file tree
Showing 4 changed files with 286 additions and 136 deletions.
106 changes: 58 additions & 48 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
webhook_asgi_app,
wsgi_app_wrapper,
)
from ._container_io_manager import ContainerIOManager, LocalInput, UserException, _ContainerIOManager
from ._container_io_manager import ContainerIOManager, FinalizedFunction, IOContext, UserException, _ContainerIOManager
from ._proxy_tunnel import proxy_tunnel
from ._serialization import deserialize, deserialize_proto_params
from ._utils.async_utils import TaskContext, synchronizer
Expand Down Expand Up @@ -101,14 +101,6 @@ def construct_webhook_callable(
raise InvalidError(f"Unrecognized web endpoint type {webhook_config.type}")


@dataclass
class FinalizedFunction:
callable: Callable[..., Any]
is_async: bool
is_generator: bool
data_format: int # api_pb2.DataFormat


class Service(metaclass=ABCMeta):
"""Common interface for singular functions and class-based "services"
Expand Down Expand Up @@ -328,26 +320,26 @@ def call_function(
container_io_manager: "modal._container_io_manager.ContainerIOManager",
finalized_functions: Dict[str, FinalizedFunction],
input_concurrency: int,
batch_max_size: Optional[int],
batch_linger_ms: Optional[int],
):
async def run_input_async(finalized_function: FinalizedFunction, local_input: LocalInput) -> None:
async def run_input_async(io_context: IOContext) -> None:
started_at = time.time()
reset_context = _set_current_context_ids(local_input.input_id, local_input.function_call_id)
async with container_io_manager.handle_input_exception.aio(local_input.input_id, started_at):
logger.debug(f"Starting input {local_input.input_id} (async)")
res = finalized_function.callable(*local_input.args, **local_input.kwargs)
logger.debug(f"Finished input {local_input.input_id} (async)")

input_ids, function_call_ids = io_context.input_ids, io_context.function_call_ids
reset_context = _set_current_context_ids(input_ids, function_call_ids)
async with container_io_manager.handle_input_exception.aio(io_context, started_at):
res = io_context.call_finalized_function()
# TODO(erikbern): any exception below shouldn't be considered a user exception
if finalized_function.is_generator:
if io_context.finalized_function.is_generator:
if not inspect.isasyncgen(res):
raise InvalidError(f"Async generator function returned value of type {type(res)}")

# Send up to this many outputs at a time.
generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024)
generator_output_task = asyncio.create_task(
container_io_manager.generator_output_task.aio(
local_input.function_call_id,
finalized_function.data_format,
function_call_ids[0],
io_context.finalized_function.data_format,
generator_queue,
)
)
Expand All @@ -360,8 +352,11 @@ async def run_input_async(finalized_function: FinalizedFunction, local_input: Lo
await container_io_manager._queue_put.aio(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
await generator_output_task # Wait to finish sending generator outputs.
message = api_pb2.GeneratorDone(items_total=item_count)
await container_io_manager.push_output.aio(
local_input.input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE
await container_io_manager.push_outputs.aio(
io_context,
started_at,
message,
api_pb2.DATA_FORMAT_GENERATOR_DONE,
)
else:
if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
Expand All @@ -370,29 +365,31 @@ async def run_input_async(finalized_function: FinalizedFunction, local_input: Lo
" You might need to use @app.function(..., is_generator=True)."
)
value = await res
await container_io_manager.push_output.aio(
local_input.input_id, started_at, value, finalized_function.data_format
await container_io_manager.push_outputs.aio(
io_context,
started_at,
value,
io_context.finalized_function.data_format,
)
reset_context()

def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInput) -> None:
def run_input_sync(io_context: IOContext) -> None:
started_at = time.time()
reset_context = _set_current_context_ids(local_input.input_id, local_input.function_call_id)
with container_io_manager.handle_input_exception(local_input.input_id, started_at):
logger.debug(f"Starting input {local_input.input_id} (sync)")
res = finalized_function.callable(*local_input.args, **local_input.kwargs)
logger.debug(f"Finished input {local_input.input_id} (sync)")
input_ids, function_call_ids = io_context.input_ids, io_context.function_call_ids
reset_context = _set_current_context_ids(input_ids, function_call_ids)
with container_io_manager.handle_input_exception(io_context, started_at):
res = io_context.call_finalized_function()

# TODO(erikbern): any exception below shouldn't be considered a user exception
if finalized_function.is_generator:
if io_context.finalized_function.is_generator:
if not inspect.isgenerator(res):
raise InvalidError(f"Generator function returned value of type {type(res)}")

# Send up to this many outputs at a time.
generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024)
generator_output_task: concurrent.futures.Future = container_io_manager.generator_output_task( # type: ignore
local_input.function_call_id,
finalized_function.data_format,
function_call_ids[0],
io_context.finalized_function.data_format,
generator_queue,
_future=True, # type: ignore # Synchronicity magic to return a future.
)
Expand All @@ -405,16 +402,16 @@ def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInpu
container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
generator_output_task.result() # Wait to finish sending generator outputs.
message = api_pb2.GeneratorDone(items_total=item_count)
container_io_manager.push_output(
local_input.input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE
)
container_io_manager.push_outputs(io_context, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
else:
if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
raise InvalidError(
f"Sync (non-generator) function return value of type {type(res)}."
" You might need to use @app.function(..., is_generator=True)."
)
container_io_manager.push_output(local_input.input_id, started_at, res, finalized_function.data_format)
container_io_manager.push_outputs(
io_context, started_at, res, io_context.finalized_function.data_format
)
reset_context()

if input_concurrency > 1:
Expand All @@ -425,24 +422,26 @@ async def run_concurrent_inputs():
# but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
# for them to resolve gracefully:
async with TaskContext(0.01) as task_context:
async for local_input in container_io_manager.run_inputs_outputs.aio(input_concurrency):
finalized_function = finalized_functions[local_input.method_name]
async for io_context in container_io_manager.run_inputs_outputs.aio(
finalized_functions, input_concurrency, batch_max_size, batch_linger_ms
):
# Note that run_inputs_outputs will not return until the concurrency semaphore has
# released all its slots so that they can be acquired by the run_inputs_outputs finalizer
# This prevents leaving the task_context before outputs have been created
# TODO: refactor to make this a bit more easy to follow?
if finalized_function.is_async:
task_context.create_task(run_input_async(finalized_function, local_input))
if io_context.finalized_function.is_async:
task_context.create_task(run_input_async(io_context))
else:
# run sync input in thread
thread_pool.submit(run_input_sync, finalized_function, local_input)
thread_pool.submit(run_input_sync, io_context)

user_code_event_loop.run(run_concurrent_inputs())
else:
for local_input in container_io_manager.run_inputs_outputs(input_concurrency):
finalized_function = finalized_functions[local_input.method_name]
if finalized_function.is_async:
user_code_event_loop.run(run_input_async(finalized_function, local_input))
for io_context in container_io_manager.run_inputs_outputs(
finalized_functions, input_concurrency, batch_max_size, batch_linger_ms
):
if io_context.finalized_function.is_async:
user_code_event_loop.run(run_input_async(io_context))
else:
# Set up a custom signal handler for `SIGUSR1`, which gets translated to an InputCancellation
# during function execution. This is sent to cancel inputs from the user
Expand All @@ -453,7 +452,7 @@ def _cancel_input_signal_handler(signum, stackframe):
# run this sync code in the main thread, blocking the "userland" event loop
# this lets us cancel it using a signal handler that raises an exception
try:
run_input_sync(finalized_function, local_input)
run_input_sync(io_context)
finally:
signal.signal(signal.SIGUSR1, usr1_handler) # reset signal handler

Expand Down Expand Up @@ -738,10 +737,14 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):

# Container can fetch multiple inputs simultaneously
if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
# Concurrency doesn't apply for `modal shell`.
# Concurrency and batching doesn't apply for `modal shell`.
input_concurrency = 1
batch_max_size = 0
batch_linger_ms = 0
else:
input_concurrency = function_def.allow_concurrent_inputs or 1
batch_max_size = function_def.batch_max_size or 0
batch_linger_ms = function_def.batch_linger_ms or 0

# Get ids and metadata for objects (primarily functions and classes) on the app
container_app: RunningApp = container_io_manager.get_app_objects()
Expand Down Expand Up @@ -803,7 +806,14 @@ def breakpoint_wrapper():

# Execute the function.
try:
call_function(event_loop, container_io_manager, finalized_functions, input_concurrency)
call_function(
event_loop,
container_io_manager,
finalized_functions,
input_concurrency,
batch_max_size,
batch_linger_ms,
)
finally:
# Run exit handlers. From this point onward, ignore all SIGINT signals that come from
# graceful shutdowns originating on the worker, as well as stray SIGUSR1 signals that
Expand Down
Loading

0 comments on commit fbe6b5b

Please sign in to comment.