Process Batched inputs and outputs on client side #2064

Merged · 25 commits · Aug 8, 2024

Changes from 20 commits
106 changes: 58 additions & 48 deletions modal/_container_entrypoint.py
@@ -28,7 +28,7 @@
webhook_asgi_app,
wsgi_app_wrapper,
)
from ._container_io_manager import ContainerIOManager, LocalInput, UserException, _ContainerIOManager
from ._container_io_manager import ContainerIOManager, FinalizedFunction, IOContext, UserException, _ContainerIOManager
from ._proxy_tunnel import proxy_tunnel
from ._serialization import deserialize, deserialize_proto_params
from ._utils.async_utils import TaskContext, synchronizer
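For reviewers skimming the hunks below: `FinalizedFunction` and the new `IOContext` are now imported from `_container_io_manager` instead of being defined in this file, and `IOContext` replaces the old per-input `LocalInput`. A minimal sketch of the shape this file relies on, inferred from the call sites in this diff rather than the actual definition in `modal/_container_io_manager.py`:

```python
# Illustrative sketch only; the real IOContext lives in modal/_container_io_manager.py
# and may carry more state (e.g. deserialization and batching logic).
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class FinalizedFunction:
    callable: Callable[..., Any]
    is_async: bool
    is_generator: bool
    data_format: int  # api_pb2.DataFormat


@dataclass
class IOContext:
    # One entry per input when the call is batched; a single entry otherwise.
    input_ids: List[str]
    function_call_ids: List[str]
    finalized_function: FinalizedFunction
    args_kwargs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]]

    def call_finalized_function(self) -> Any:
        # Hypothetical: a batched call would first combine the arguments of all
        # inputs; only the single-input case is shown here.
        args, kwargs = self.args_kwargs[0]
        return self.finalized_function.callable(*args, **kwargs)
```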
@@ -101,14 +101,6 @@ def construct_webhook_callable(
raise InvalidError(f"Unrecognized web endpoint type {webhook_config.type}")


@dataclass
class FinalizedFunction:
callable: Callable[..., Any]
is_async: bool
is_generator: bool
data_format: int # api_pb2.DataFormat


class Service(metaclass=ABCMeta):
"""Common interface for singular functions and class-based "services"

@@ -328,26 +320,26 @@ def call_function(
container_io_manager: "modal._container_io_manager.ContainerIOManager",
finalized_functions: Dict[str, FinalizedFunction],
input_concurrency: int,
batch_max_size: Optional[int],
batch_linger_ms: Optional[int],
):
async def run_input_async(finalized_function: FinalizedFunction, local_input: LocalInput) -> None:
async def run_input_async(io_context: IOContext) -> None:
started_at = time.time()
reset_context = _set_current_context_ids(local_input.input_id, local_input.function_call_id)
async with container_io_manager.handle_input_exception.aio(local_input.input_id, started_at):
logger.debug(f"Starting input {local_input.input_id} (async)")
res = finalized_function.callable(*local_input.args, **local_input.kwargs)
logger.debug(f"Finished input {local_input.input_id} (async)")

input_ids, function_call_ids = io_context.input_ids, io_context.function_call_ids
reset_context = _set_current_context_ids(input_ids, function_call_ids)
async with container_io_manager.handle_input_exception.aio(io_context, started_at):
res = io_context.call_finalized_function()
# TODO(erikbern): any exception below shouldn't be considered a user exception
if finalized_function.is_generator:
if io_context.finalized_function.is_generator:
if not inspect.isasyncgen(res):
raise InvalidError(f"Async generator function returned value of type {type(res)}")

# Send up to this many outputs at a time.
generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024)
generator_output_task = asyncio.create_task(
container_io_manager.generator_output_task.aio(
local_input.function_call_id,
finalized_function.data_format,
function_call_ids[0],
io_context.finalized_function.data_format,
generator_queue,
)
)
@@ -360,8 +352,11 @@ async def run_input_async(finalized_function: FinalizedFunction, local_input: Lo
await container_io_manager._queue_put.aio(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
await generator_output_task # Wait to finish sending generator outputs.
message = api_pb2.GeneratorDone(items_total=item_count)
await container_io_manager.push_output.aio(
local_input.input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE
await container_io_manager.push_outputs.aio(
io_context,
started_at,
message,
api_pb2.DATA_FORMAT_GENERATOR_DONE,
)
else:
if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
@@ -370,29 +365,31 @@ async def run_input_async(finalized_function: FinalizedFunction, local_input: Lo
" You might need to use @app.function(..., is_generator=True)."
)
value = await res
await container_io_manager.push_output.aio(
local_input.input_id, started_at, value, finalized_function.data_format
await container_io_manager.push_outputs.aio(
io_context,
started_at,
value,
io_context.finalized_function.data_format,
)
reset_context()

def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInput) -> None:
def run_input_sync(io_context: IOContext) -> None:
started_at = time.time()
reset_context = _set_current_context_ids(local_input.input_id, local_input.function_call_id)
with container_io_manager.handle_input_exception(local_input.input_id, started_at):
logger.debug(f"Starting input {local_input.input_id} (sync)")
res = finalized_function.callable(*local_input.args, **local_input.kwargs)
logger.debug(f"Finished input {local_input.input_id} (sync)")
input_ids, function_call_ids = io_context.input_ids, io_context.function_call_ids
reset_context = _set_current_context_ids(input_ids, function_call_ids)
with container_io_manager.handle_input_exception(io_context, started_at):
res = io_context.call_finalized_function()

# TODO(erikbern): any exception below shouldn't be considered a user exception
if finalized_function.is_generator:
if io_context.finalized_function.is_generator:
if not inspect.isgenerator(res):
raise InvalidError(f"Generator function returned value of type {type(res)}")

# Send up to this many outputs at a time.
generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024)
generator_output_task: concurrent.futures.Future = container_io_manager.generator_output_task( # type: ignore
local_input.function_call_id,
finalized_function.data_format,
function_call_ids[0],
io_context.finalized_function.data_format,
generator_queue,
_future=True, # type: ignore # Synchronicity magic to return a future.
)
@@ -405,16 +402,16 @@ def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInpu
container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
generator_output_task.result() # Wait to finish sending generator outputs.
message = api_pb2.GeneratorDone(items_total=item_count)
container_io_manager.push_output(
local_input.input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE
)
container_io_manager.push_outputs(io_context, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
else:
if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
raise InvalidError(
f"Sync (non-generator) function return value of type {type(res)}."
" You might need to use @app.function(..., is_generator=True)."
)
container_io_manager.push_output(local_input.input_id, started_at, res, finalized_function.data_format)
container_io_manager.push_outputs(
io_context, started_at, res, io_context.finalized_function.data_format
)
reset_context()

if input_concurrency > 1:
@@ -425,24 +422,26 @@ async def run_concurrent_inputs():
# but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
# for them to resolve gracefully:
async with TaskContext(0.01) as task_context:
async for local_input in container_io_manager.run_inputs_outputs.aio(input_concurrency):
finalized_function = finalized_functions[local_input.method_name]
async for io_context in container_io_manager.run_inputs_outputs.aio(
finalized_functions, input_concurrency, batch_max_size, batch_linger_ms
):
# Note that run_inputs_outputs will not return until the concurrency semaphore has
# released all its slots so that they can be acquired by the run_inputs_outputs finalizer
# This prevents leaving the task_context before outputs have been created
# TODO: refactor to make this a bit more easy to follow?
if finalized_function.is_async:
task_context.create_task(run_input_async(finalized_function, local_input))
if io_context.finalized_function.is_async:
task_context.create_task(run_input_async(io_context))
else:
# run sync input in thread
thread_pool.submit(run_input_sync, finalized_function, local_input)
thread_pool.submit(run_input_sync, io_context)

user_code_event_loop.run(run_concurrent_inputs())
else:
for local_input in container_io_manager.run_inputs_outputs(input_concurrency):
finalized_function = finalized_functions[local_input.method_name]
if finalized_function.is_async:
user_code_event_loop.run(run_input_async(finalized_function, local_input))
for io_context in container_io_manager.run_inputs_outputs(
finalized_functions, input_concurrency, batch_max_size, batch_linger_ms
):
if io_context.finalized_function.is_async:
user_code_event_loop.run(run_input_async(io_context))
else:
# Set up a custom signal handler for `SIGUSR1`, which gets translated to an InputCancellation
# during function execution. This is sent to cancel inputs from the user
@@ -453,7 +452,7 @@ def _cancel_input_signal_handler(signum, stackframe):
# run this sync code in the main thread, blocking the "userland" event loop
# this lets us cancel it using a signal handler that raises an exception
try:
run_input_sync(finalized_function, local_input)
run_input_sync(io_context)
finally:
signal.signal(signal.SIGUSR1, usr1_handler) # reset signal handler

@@ -738,10 +737,14 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):

# Container can fetch multiple inputs simultaneously
if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
# Concurrency doesn't apply for `modal shell`.
# Concurrency and batching don't apply for `modal shell`.
input_concurrency = 1
batch_max_size = 0
batch_linger_ms = 0
else:
input_concurrency = function_def.allow_concurrent_inputs or 1
batch_max_size = function_def.batch_max_size or 0
batch_linger_ms = function_def.batch_linger_ms or 0

# Get ids and metadata for objects (primarily functions and classes) on the app
container_app: RunningApp = container_io_manager.get_app_objects()
@@ -803,7 +806,14 @@ def breakpoint_wrapper():

# Execute the function.
try:
call_function(event_loop, container_io_manager, finalized_functions, input_concurrency)
call_function(
event_loop,
container_io_manager,
finalized_functions,
input_concurrency,
batch_max_size,
batch_linger_ms,
)
finally:
# Run exit handlers. From this point onward, ignore all SIGINT signals that come from
# graceful shutdowns originating on the worker, as well as stray SIGUSR1 signals that
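Stepping back from the hunks above: the rename from `push_output` to `push_outputs` reflects the client-side half of the PR title — for a batched call, the single return value has to be split back into one output per input id before results are pushed to the server. A hedged sketch of that fan-out (the real logic belongs to `ContainerIOManager.push_outputs` and may differ):

```python
from typing import Any, Dict, Sequence


def split_batched_result(result: Any, input_ids: Sequence[str]) -> Dict[str, Any]:
    """Map a (possibly batched) return value back onto its input ids.

    Assumption: a batched function returns a list with one element per input,
    in the order the inputs were received; unbatched calls have a single id.
    """
    if len(input_ids) == 1:
        return {input_ids[0]: result}
    if not isinstance(result, list) or len(result) != len(input_ids):
        raise ValueError(
            f"Batched function must return a list of {len(input_ids)} results, "
            f"got {type(result)}"
        )
    return dict(zip(input_ids, result))
```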