Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support dynamic max_concurrency in client concurrency logic #2158

Merged
merged 18 commits into from
Aug 30, 2024
61 changes: 34 additions & 27 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
def construct_webhook_callable(
user_defined_callable: Callable,
webhook_config: api_pb2.WebhookConfig,
container_io_manager: "modal._container_io_manager.ContainerIOManager",
container_io_manager: ContainerIOManager,
):
# For webhooks, the user function is used to construct an asgi app:
if webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP:
Expand Down Expand Up @@ -211,8 +211,8 @@ class DaemonizedThreadPool:
# Used instead of ThreadPoolExecutor, since the latter won't allow
# the interpreter to shut down before the currently running tasks
# have finished
def __init__(self, max_threads):
self.max_threads = max_threads
def __init__(self, container_io_manager: "modal._container_io_manager.ContainerIOManager"):
self.container_io_manager = container_io_manager

def __enter__(self):
self.spawned_workers = 0
Expand Down Expand Up @@ -246,7 +246,7 @@ def worker_thread():
logger.exception(f"Exception raised by {_func} in DaemonizedThreadPool worker!")
self.inputs.task_done()

if self.spawned_workers < self.max_threads:
if self.spawned_workers < self.container_io_manager.get_input_concurrency():
cathyzbn marked this conversation as resolved.
Show resolved Hide resolved
threading.Thread(target=worker_thread, daemon=True).start()
self.spawned_workers += 1

Expand Down Expand Up @@ -321,9 +321,9 @@ def call_function(
user_code_event_loop: UserCodeEventLoop,
container_io_manager: "modal._container_io_manager.ContainerIOManager",
finalized_functions: Dict[str, FinalizedFunction],
input_concurrency: int,
batch_max_size: Optional[int],
batch_wait_ms: Optional[int],
target_input_concurrency: int,
batch_max_size: int,
batch_wait_ms: int,
):
async def run_input_async(io_context: IOContext) -> None:
started_at = time.time()
Expand Down Expand Up @@ -416,8 +416,8 @@ def run_input_sync(io_context: IOContext) -> None:
)
reset_context()

if input_concurrency > 1:
with DaemonizedThreadPool(max_threads=input_concurrency) as thread_pool:
if target_input_concurrency > 1:
with DaemonizedThreadPool(container_io_manager) as thread_pool:

def make_async_cancel_callback(task):
def f():
Expand Down Expand Up @@ -448,9 +448,9 @@ async def run_concurrent_inputs():
# for them to resolve gracefully:
async with TaskContext(0.01) as task_context:
async for io_context in container_io_manager.run_inputs_outputs.aio(
finalized_functions, input_concurrency, batch_max_size, batch_wait_ms
finalized_functions, batch_max_size, batch_wait_ms
):
# Note that run_inputs_outputs will not return until the concurrency semaphore has
# Note that run_inputs_outputs will not return until the concurrency manager has
# released all its slots so that they can be acquired by the run_inputs_outputs finalizer
# This prevents leaving the task_context before outputs have been created
# TODO: refactor to make this a bit more easy to follow?
Expand All @@ -464,9 +464,7 @@ async def run_concurrent_inputs():

user_code_event_loop.run(run_concurrent_inputs())
else:
for io_context in container_io_manager.run_inputs_outputs(
finalized_functions, input_concurrency, batch_max_size, batch_wait_ms
):
for io_context in container_io_manager.run_inputs_outputs(finalized_functions, batch_max_size, batch_wait_ms):
if io_context.finalized_function.is_async:
user_code_event_loop.run(run_input_async(io_context))
else:
Expand Down Expand Up @@ -717,7 +715,6 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
def main(container_args: api_pb2.ContainerArguments, client: Client):
# This is a bit weird but we need both the blocking and async versions of ContainerIOManager.
# At some point, we should fix that by having built-in support for running "user code"
container_io_manager = ContainerIOManager(container_args, client)
active_app: Optional[_App] = None
service: Service
function_def = container_args.function_def
Expand All @@ -728,6 +725,21 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
function_def.is_checkpointing_function and os.environ.get("MODAL_ENABLE_SNAP_RESTORE", "0") == "1"
)

# Container can fetch multiple inputs simultaneously
if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
# Concurrency and batching doesn't apply for `modal shell`.
target_concurrency = 1
max_concurrency = 0
batch_max_size = 0
batch_wait_ms = 0
else:
target_concurrency = function_def.allow_concurrent_inputs or 1
max_concurrency = function_def.max_concurrent_inputs or 0
batch_max_size = function_def.batch_max_size or 0
batch_wait_ms = function_def.batch_linger_ms or 0

container_io_manager = ContainerIOManager(container_args, client, target_concurrency, max_concurrency)

_client: _Client = synchronizer._translate_in(client) # TODO(erikbern): ugly

with container_io_manager.heartbeats(is_snapshotting_function), UserCodeEventLoop() as event_loop:
Expand All @@ -739,6 +751,12 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):

# Initialize the function, importing user code.
with container_io_manager.handle_user_exception():
if max_concurrency != 0 and max_concurrency <= target_concurrency:
cathyzbn marked this conversation as resolved.
Show resolved Hide resolved
raise InvalidError("max_concurrent_inputs must be greater than allow_concurrent_inputs.")
if max_concurrency != 0 and target_concurrency <= 1:
raise InvalidError(
"allow_concurrent_inputs must be greater than 1 to enable automatic input concurrency scaling."
)
if container_args.serialized_params:
param_args, param_kwargs = deserialize_params(container_args.serialized_params, function_def, _client)
else:
Expand Down Expand Up @@ -767,17 +785,6 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
# if the app can't be inferred by the imported function, use name-based fallback
active_app = get_active_app_fallback(function_def)

# Container can fetch multiple inputs simultaneously
if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
# Concurrency and batching doesn't apply for `modal shell`.
input_concurrency = 1
batch_max_size = 0
batch_wait_ms = 0
else:
input_concurrency = function_def.allow_concurrent_inputs or 1
batch_max_size = function_def.batch_max_size or 0
batch_wait_ms = function_def.batch_linger_ms or 0

# Get ids and metadata for objects (primarily functions and classes) on the app
container_app: RunningApp = container_io_manager.get_app_objects()

Expand Down Expand Up @@ -842,7 +849,7 @@ def breakpoint_wrapper():
event_loop,
container_io_manager,
finalized_functions,
input_concurrency,
target_concurrency,
batch_max_size,
batch_wait_ms,
)
Expand Down
Loading
Loading