Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User interface for @batched #2065

Merged
merged 33 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
33 commits
Select commit — hold Shift + click to select a range
dc66405
initial
cathyzbn Jul 25, 2024
8232960
Merge branch 'main' into cathy/batching-integration
cathyzbn Jul 25, 2024
bd80e7d
enable class method to batch
cathyzbn Jul 26, 2024
2b8906b
fix current_id and deserialization latency
cathyzbn Jul 29, 2024
9a83ef5
Merge branch 'main' into cathy/batching-integration
cathyzbn Jul 29, 2024
c6d8ff0
cleanup and unit test
cathyzbn Jul 31, 2024
3df6fb4
Merge branch 'cathy/batching-integration' of github.com:modal-labs/mo…
cathyzbn Jul 31, 2024
96d7f2a
fix type check
cathyzbn Jul 31, 2024
c71d982
fix type check
cathyzbn Jul 31, 2024
3ebf9cb
Merge branch 'main' into cathy/batching-integration
cathyzbn Jul 31, 2024
fd50bd1
Merge branch 'main' into cathy/batching-integration
cathyzbn Jul 31, 2024
98b6146
Merge branch 'main' into cathy/batching-integration
cathyzbn Jul 31, 2024
15be068
fix test on linux
cathyzbn Jul 31, 2024
edd94ce
Merge branch 'cathy/batching-integration' of github.com:modal-labs/mo…
cathyzbn Jul 31, 2024
7d9896e
process batched inputs on client side
cathyzbn Jul 31, 2024
fc206a4
Merge branch 'main' into cathy/batch_input_output
cathyzbn Jul 31, 2024
55af973
isolate input/output change
cathyzbn Jul 31, 2024
c1cec07
upload blob data without blocking
cathyzbn Aug 1, 2024
651f838
fix type check
cathyzbn Aug 1, 2024
b4f09d8
cleanup
cathyzbn Aug 1, 2024
853c37b
fix type check
cathyzbn Aug 1, 2024
ecad424
Merge branch 'cathy/batching-integration' of github.com:modal-labs/mo…
cathyzbn Aug 1, 2024
e4b7b79
update function names
cathyzbn Aug 1, 2024
dfe14f3
Merge branch 'main' into cathy/batch_user_interface
cathyzbn Aug 8, 2024
368a239
Merge branch 'main' into cathy/batch_user_interface
cathyzbn Aug 8, 2024
c9b42d6
change param name + add errors check
cathyzbn Aug 8, 2024
17f4e91
nits
cathyzbn Aug 8, 2024
add2fa9
type check
cathyzbn Aug 8, 2024
31033ee
fix test
cathyzbn Aug 8, 2024
e80d937
fix test
cathyzbn Aug 8, 2024
88a7f03
Merge branch 'main' into cathy/batch_user_interface
cathyzbn Aug 12, 2024
91c1c83
Update container_test.py
cathyzbn Aug 12, 2024
0700ea7
add docstring
cathyzbn Aug 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion modal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .image import Image
from .mount import Mount
from .network_file_system import NetworkFileSystem
from .partial_function import asgi_app, build, enter, exit, method, web_endpoint, web_server, wsgi_app
from .partial_function import asgi_app, batched, build, enter, exit, method, web_endpoint, web_server, wsgi_app
from .proxy import Proxy
from .queue import Queue
from .retries import Retries
Expand Down Expand Up @@ -64,6 +64,7 @@
"Tunnel",
"Volume",
"asgi_app",
"batched",
"build",
"current_function_call_id",
"current_input_id",
Expand Down
159 changes: 130 additions & 29 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import typing
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from google.protobuf.message import Message
from synchronicity import Interface
Expand Down Expand Up @@ -323,30 +323,87 @@ def _sigint_handler():
self.loop.remove_signal_handler(signal.SIGINT)


def _aggregate_args_and_kwargs(
    local_inputs: Union[LocalInput, List[LocalInput]],
    callable: Callable[..., Any],
) -> Tuple[Union[str, List[str]], Union[str, List[str]], Tuple[Any, ...], Dict[str, Any]]:
    """Normalize a single input or a batch of inputs into one invocation of `callable`.

    For a single `LocalInput`, the ids, args, and kwargs pass through unchanged.
    For a batch, every parameter of `callable` is mapped to the list of per-call
    values (one entry per input, in input order), and the aggregated call is
    expressed purely as kwargs (empty args tuple).

    Returns:
        Tuple of (input id(s), function call id(s), args, kwargs).

    Raises:
        InvalidError: if the batched callable declares default arguments, or if any
            call supplies the wrong number of arguments, an unexpected keyword, or
            multiple values for the same parameter.
    """
    # Single (non-batched) input: nothing to aggregate.
    if isinstance(local_inputs, LocalInput):
        return local_inputs.input_id, local_inputs.function_call_id, local_inputs.args, local_inputs.kwargs

    # Aggregate args and kwargs for batched input.
    input_ids = [local_input.input_id for local_input in local_inputs]
    function_call_ids = [local_input.function_call_id for local_input in local_inputs]

    # Introspect the signature once — inspect.signature() is comparatively expensive.
    parameters = inspect.signature(callable).parameters
    param_names = list(parameters.keys())
    for param in parameters.values():
        if param.default is not inspect.Parameter.empty:
            raise InvalidError(f"Modal batched function {callable.__name__} does not accept default arguments.")

    # For each input, resolve positional and keyword arguments to parameter names.
    args_by_inputs: List[Dict[str, Any]] = [{} for _ in range(len(local_inputs))]
    for i, local_input in enumerate(local_inputs):
        params_len = len(local_input.args) + len(local_input.kwargs)
        if params_len != len(param_names):
            # Note: the count covers positional *and* keyword arguments of this call.
            raise InvalidError(
                f"Modal batched function {callable.__name__} takes {len(param_names)} arguments, but one call has {params_len}."  # noqa
            )
        for j, arg in enumerate(local_input.args):
            args_by_inputs[i][param_names[j]] = arg
        for k, v in local_input.kwargs.items():
            if k not in param_names:
                raise InvalidError(
                    f"Modal batched function {callable.__name__} got an unexpected keyword argument {k} in one call."
                )
            if k in args_by_inputs[i]:
                raise InvalidError(
                    f"Modal batched function {callable.__name__} got multiple values for argument {k} in one call."
                )
            args_by_inputs[i][k] = v

    # Transpose: param name -> list of that param's values across the batch,
    # with each list having length len(local_inputs).
    formatted_kwargs = {
        param_name: [args_by_inputs[i][param_name] for i in range(len(local_inputs))] for param_name in param_names
    }
    return input_ids, function_call_ids, tuple(), formatted_kwargs


def call_function(
user_code_event_loop: UserCodeEventLoop,
container_io_manager: "modal._container_io_manager.ContainerIOManager",
finalized_functions: Dict[str, FinalizedFunction],
input_concurrency: int,
batch_max_size: Optional[int],
batch_wait_ms: Optional[int],
):
async def run_input_async(finalized_function: FinalizedFunction, local_input: LocalInput) -> None:
async def run_input_async(
finalized_function: FinalizedFunction,
local_inputs: Union[LocalInput, List[LocalInput]],
container_io_manager: "modal._container_io_manager.ContainerIOManager",
) -> None:
started_at = time.time()
reset_context = _set_current_context_ids(local_input.input_id, local_input.function_call_id)
async with container_io_manager.handle_input_exception.aio(local_input.input_id, started_at):
logger.debug(f"Starting input {local_input.input_id} (async)")
res = finalized_function.callable(*local_input.args, **local_input.kwargs)
logger.debug(f"Finished input {local_input.input_id} (async)")
input_ids, function_call_ids, args, kwargs = _aggregate_args_and_kwargs(
local_inputs, finalized_function.callable
)
reset_context = _set_current_context_ids(input_ids, function_call_ids)
async with container_io_manager.handle_input_exception.aio(
input_ids, started_at, finalized_function.callable.__name__
):
logger.debug(f"Starting input {input_ids} (async)")
res = finalized_function.callable(*args, **kwargs)
logger.debug(f"Finished input {input_ids} (async)")

# TODO(erikbern): any exception below shouldn't be considered a user exception
if finalized_function.is_generator:
if not inspect.isasyncgen(res):
raise InvalidError(f"Async generator function returned value of type {type(res)}")
if isinstance(input_ids, list) or isinstance(function_call_ids, list):
raise InvalidError(
f"Batch function {finalized_function.callable.__name__} cannot return generators."
cathyzbn marked this conversation as resolved.
Show resolved Hide resolved
)

# Send up to this many outputs at a time.
generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024)
generator_output_task = asyncio.create_task(
container_io_manager.generator_output_task.aio(
local_input.function_call_id,
function_call_ids,
finalized_function.data_format,
generator_queue,
)
Expand All @@ -361,7 +418,11 @@ async def run_input_async(finalized_function: FinalizedFunction, local_input: Lo
await generator_output_task # Wait to finish sending generator outputs.
message = api_pb2.GeneratorDone(items_total=item_count)
await container_io_manager.push_output.aio(
local_input.input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE
input_ids,
started_at,
finalized_function.callable.__name__,
message,
api_pb2.DATA_FORMAT_GENERATOR_DONE,
)
else:
if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
Expand All @@ -371,27 +432,38 @@ async def run_input_async(finalized_function: FinalizedFunction, local_input: Lo
)
value = await res
await container_io_manager.push_output.aio(
local_input.input_id, started_at, value, finalized_function.data_format
input_ids, started_at, finalized_function.callable.__name__, value, finalized_function.data_format
)
reset_context()

def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInput) -> None:
def run_input_sync(
finalized_function: FinalizedFunction,
local_inputs: Union[LocalInput, List[LocalInput]],
container_io_manager: "modal._container_io_manager.ContainerIOManager",
cathyzbn marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
started_at = time.time()
reset_context = _set_current_context_ids(local_input.input_id, local_input.function_call_id)
with container_io_manager.handle_input_exception(local_input.input_id, started_at):
logger.debug(f"Starting input {local_input.input_id} (sync)")
res = finalized_function.callable(*local_input.args, **local_input.kwargs)
logger.debug(f"Finished input {local_input.input_id} (sync)")
input_ids, function_call_ids, args, kwargs = _aggregate_args_and_kwargs(
local_inputs, finalized_function.callable
)
reset_context = _set_current_context_ids(input_ids, function_call_ids)
with container_io_manager.handle_input_exception(input_ids, started_at, finalized_function.callable.__name__):
logger.debug(f"Starting input {input_ids} (sync)")
res = finalized_function.callable(*args, **kwargs)
logger.debug(f"Finished input {input_ids} (sync)")

# TODO(erikbern): any exception below shouldn't be considered a user exception
if finalized_function.is_generator:
if not inspect.isgenerator(res):
raise InvalidError(f"Generator function returned value of type {type(res)}")
if isinstance(function_call_ids, list):
raise InvalidError(
f"Batch function {finalized_function.callable.__name__} cannot return generators."
)
cathyzbn marked this conversation as resolved.
Show resolved Hide resolved

# Send up to this many outputs at a time.
generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024)
generator_output_task: concurrent.futures.Future = container_io_manager.generator_output_task( # type: ignore
local_input.function_call_id,
function_call_ids,
finalized_function.data_format,
generator_queue,
_future=True, # type: ignore # Synchronicity magic to return a future.
Expand All @@ -406,17 +478,31 @@ def run_input_sync(finalized_function: FinalizedFunction, local_input: LocalInpu
generator_output_task.result() # Wait to finish sending generator outputs.
message = api_pb2.GeneratorDone(items_total=item_count)
container_io_manager.push_output(
local_input.input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE
input_ids,
started_at,
finalized_function.callable.__name__,
message,
api_pb2.DATA_FORMAT_GENERATOR_DONE,
)
else:
if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
raise InvalidError(
f"Sync (non-generator) function return value of type {type(res)}."
" You might need to use @app.function(..., is_generator=True)."
)
container_io_manager.push_output(local_input.input_id, started_at, res, finalized_function.data_format)
container_io_manager.push_output(
input_ids, started_at, finalized_function.callable.__name__, res, finalized_function.data_format
)
reset_context()

def _get_finalized_functions(local_inputs: Union[LocalInput, List[LocalInput]]) -> FinalizedFunction:
    # Resolve the FinalizedFunction for either a single input or a batch.
    # A batch must be non-empty and homogeneous: every input targets the same method.
    if isinstance(local_inputs, list):
        assert len(local_inputs) > 0
        assert len({local_input.method_name for local_input in local_inputs}) == 1
        return finalized_functions[local_inputs[0].method_name]
    else:
        return finalized_functions[local_inputs.method_name]

if input_concurrency > 1:
with DaemonizedThreadPool(max_threads=input_concurrency) as thread_pool:

Expand All @@ -425,24 +511,28 @@ async def run_concurrent_inputs():
# but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
# for them to resolve gracefully:
async with TaskContext(0.01) as task_context:
async for local_input in container_io_manager.run_inputs_outputs.aio(input_concurrency):
finalized_function = finalized_functions[local_input.method_name]
async for local_inputs in container_io_manager.run_inputs_outputs.aio(
input_concurrency, batch_max_size, batch_wait_ms
):
finalized_function = _get_finalized_functions(local_inputs)
# Note that run_inputs_outputs will not return until the concurrency semaphore has
# released all its slots so that they can be acquired by the run_inputs_outputs finalizer
# This prevents leaving the task_context before outputs have been created
# TODO: refactor to make this a bit more easy to follow?
if finalized_function.is_async:
task_context.create_task(run_input_async(finalized_function, local_input))
task_context.create_task(
run_input_async(finalized_function, local_inputs, container_io_manager)
)
else:
# run sync input in thread
thread_pool.submit(run_input_sync, finalized_function, local_input)
thread_pool.submit(run_input_sync, finalized_function, local_inputs, container_io_manager)

user_code_event_loop.run(run_concurrent_inputs())
else:
for local_input in container_io_manager.run_inputs_outputs(input_concurrency):
finalized_function = finalized_functions[local_input.method_name]
for local_inputs in container_io_manager.run_inputs_outputs(input_concurrency, batch_max_size, batch_wait_ms):
finalized_function = _get_finalized_functions(local_inputs)
if finalized_function.is_async:
user_code_event_loop.run(run_input_async(finalized_function, local_input))
user_code_event_loop.run(run_input_async(finalized_function, local_inputs, container_io_manager))
else:
# Set up a custom signal handler for `SIGUSR1`, which gets translated to an InputCancellation
# during function execution. This is sent to cancel inputs from the user
Expand All @@ -453,7 +543,7 @@ def _cancel_input_signal_handler(signum, stackframe):
# run this sync code in the main thread, blocking the "userland" event loop
# this lets us cancel it using a signal handler that raises an exception
try:
run_input_sync(finalized_function, local_input)
run_input_sync(finalized_function, local_inputs, container_io_manager)
finally:
signal.signal(signal.SIGUSR1, usr1_handler) # reset signal handler

Expand Down Expand Up @@ -738,10 +828,14 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):

# Container can fetch multiple inputs simultaneously
if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
# Concurrency doesn't apply for `modal shell`.
# Concurrency and batching don't apply for `modal shell`.
input_concurrency = 1
batch_max_size = 0
batch_wait_ms = 0
else:
input_concurrency = function_def.allow_concurrent_inputs or 1
batch_max_size = function_def.batch_max_size or 0
batch_wait_ms = function_def.batch_wait_ms or 0

# Get ids and metadata for objects (primarily functions and classes) on the app
container_app: RunningApp = container_io_manager.get_app_objects()
Expand Down Expand Up @@ -803,7 +897,14 @@ def breakpoint_wrapper():

# Execute the function.
try:
call_function(event_loop, container_io_manager, finalized_functions, input_concurrency)
call_function(
event_loop,
container_io_manager,
finalized_functions,
input_concurrency,
batch_max_size,
batch_wait_ms,
)
finally:
# Run exit handlers. From this point onward, ignore all SIGINT signals that come from
# graceful shutdowns originating on the worker, as well as stray SIGUSR1 signals that
Expand Down
Loading
Loading