Skip to content

Commit

Permalink
Support dynamic max_concurrency in client concurrency logic (#2158)
Browse files Browse the repository at this point in the history
  • Loading branch information
cathyzbn committed Aug 30, 2024
1 parent 36a30c2 commit 194053b
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 53 deletions.
39 changes: 16 additions & 23 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class DaemonizedThreadPool:
# Used instead of ThreadPoolExecutor, since the latter won't allow
# the interpreter to shut down before the currently running tasks
# have finished
def __init__(self, max_threads):
def __init__(self, max_threads: int):
self.max_threads = max_threads

def __enter__(self):
Expand Down Expand Up @@ -321,9 +321,8 @@ def call_function(
user_code_event_loop: UserCodeEventLoop,
container_io_manager: "modal._container_io_manager.ContainerIOManager",
finalized_functions: Dict[str, FinalizedFunction],
input_concurrency: int,
batch_max_size: Optional[int],
batch_wait_ms: Optional[int],
batch_max_size: int,
batch_wait_ms: int,
):
async def run_input_async(io_context: IOContext) -> None:
started_at = time.time()
Expand Down Expand Up @@ -416,8 +415,8 @@ def run_input_sync(io_context: IOContext) -> None:
)
reset_context()

if input_concurrency > 1:
with DaemonizedThreadPool(max_threads=input_concurrency) as thread_pool:
if container_io_manager.target_concurrency > 1:
with DaemonizedThreadPool(max_threads=container_io_manager.max_concurrency) as thread_pool:

def make_async_cancel_callback(task):
def f():
Expand Down Expand Up @@ -448,10 +447,10 @@ async def run_concurrent_inputs():
# for them to resolve gracefully:
async with TaskContext(0.01) as task_context:
async for io_context in container_io_manager.run_inputs_outputs.aio(
finalized_functions, input_concurrency, batch_max_size, batch_wait_ms
finalized_functions, batch_max_size, batch_wait_ms
):
# Note that run_inputs_outputs will not return until the concurrency semaphore has
# released all its slots so that they can be acquired by the run_inputs_outputs finalizer
# Note that run_inputs_outputs will not return until all the input slots are released
# so that they can be acquired by the run_inputs_outputs finalizer
# This prevents leaving the task_context before outputs have been created
# TODO: refactor to make this a bit more easy to follow?
if io_context.finalized_function.is_async:
Expand All @@ -464,9 +463,7 @@ async def run_concurrent_inputs():

user_code_event_loop.run(run_concurrent_inputs())
else:
for io_context in container_io_manager.run_inputs_outputs(
finalized_functions, input_concurrency, batch_max_size, batch_wait_ms
):
for io_context in container_io_manager.run_inputs_outputs(finalized_functions, batch_max_size, batch_wait_ms):
if io_context.finalized_function.is_async:
user_code_event_loop.run(run_input_async(io_context))
else:
Expand Down Expand Up @@ -767,16 +764,13 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
# if the app can't be inferred by the imported function, use name-based fallback
active_app = get_active_app_fallback(function_def)

# Container can fetch multiple inputs simultaneously
if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
# Concurrency and batching doesn't apply for `modal shell`.
input_concurrency = 1
batch_max_size = 0
batch_wait_ms = 0
else:
input_concurrency = function_def.allow_concurrent_inputs or 1
batch_max_size = function_def.batch_max_size or 0
batch_wait_ms = function_def.batch_linger_ms or 0
if function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
# Concurrency and batching don't apply for `modal shell`.
batch_max_size = 0
batch_wait_ms = 0
else:
batch_max_size = function_def.batch_max_size or 0
batch_wait_ms = function_def.batch_linger_ms or 0

# Get ids and metadata for objects (primarily functions and classes) on the app
container_app: RunningApp = container_io_manager.get_app_objects()
Expand Down Expand Up @@ -842,7 +836,6 @@ def breakpoint_wrapper():
event_loop,
container_io_manager,
finalized_functions,
input_concurrency,
batch_max_size,
batch_wait_ms,
)
Expand Down
161 changes: 134 additions & 27 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
import time
import traceback
from contextlib import AsyncExitStack
from dataclasses import dataclass
from pathlib import Path
from typing import Any, AsyncGenerator, AsyncIterator, Callable, ClassVar, Dict, List, Optional, Tuple
Expand All @@ -30,6 +31,8 @@
from .exception import InputCancellation, InvalidError
from .running_app import RunningApp

DYNAMIC_CONCURRENCY_INTERVAL_SECS = 3
DYNAMIC_CONCURRENCY_TIMEOUT_SECS = 10
MAX_OUTPUT_BATCH_SIZE: int = 49

RTT_S: float = 0.5 # conservative estimate of RTT in seconds.
Expand Down Expand Up @@ -177,6 +180,51 @@ def validate_output_data(self, data: Any) -> List[Any]:
return data


class InputSlots:
    """A semaphore that allows dynamically adjusting the concurrency.

    At most one coroutine may be blocked in `acquire()` at any time. The limit
    (`value`) can be raised or lowered while slots are held; lowering it does
    not revoke held slots, it only delays further acquisitions until enough
    slots have been released.
    """

    # Number of slots currently held.
    active: int
    # Current slot limit.
    value: int
    # Future the single blocked `acquire()` call is waiting on, if any.
    waiter: Optional[asyncio.Future]
    # Once set, `set_value()` becomes a no-op.
    closed: bool

    def __init__(self, value: int) -> None:
        self.active = 0
        self.value = value
        self.waiter = None
        self.closed = False

    async def acquire(self) -> None:
        """Take one slot, waiting until one is free if the limit is reached.

        Raises:
            RuntimeError: if another `acquire()` call is already waiting.
        """
        if self.active < self.value:
            self.active += 1
        elif self.waiter is None:
            waiter = self.waiter = asyncio.get_running_loop().create_future()
            try:
                await waiter
            except asyncio.CancelledError:
                if self.waiter is waiter:
                    # Cancelled while still queued: drop the stale future so that
                    # later release()/acquire() calls are not poisoned by it.
                    self.waiter = None
                elif waiter.done() and not waiter.cancelled():
                    # The slot was already handed to us (and counted in `active`)
                    # before the cancellation was delivered; give it back.
                    self.release()
                raise
        else:
            raise RuntimeError("Concurrent waiters are not supported.")

    def _wake_waiter(self) -> None:
        """Hand a free slot to the blocked `acquire()` call, if there is one."""
        waiter = self.waiter
        if waiter is None or self.active >= self.value:
            return
        self.waiter = None
        if waiter.done():
            # The waiter was cancelled; calling set_result() on it would raise
            # InvalidStateError, and nobody is left to consume the slot.
            return
        waiter.set_result(None)
        self.active += 1

    def release(self) -> None:
        """Return one slot and wake the waiter if that frees capacity."""
        self.active -= 1
        self._wake_waiter()

    def set_value(self, value: int) -> None:
        """Change the slot limit; ignored once `close()` has been called."""
        if self.closed:
            return
        self.value = value
        self._wake_waiter()

    async def close(self) -> None:
        """Freeze the limit, then acquire `value` slots, waiting for holders to release."""
        self.closed = True
        for _ in range(self.value):
            await self.acquire()


class _ContainerIOManager:
"""Synchronizes all RPC calls and network operations for a running container.
Expand All @@ -196,8 +244,11 @@ class _ContainerIOManager:
current_inputs: Dict[str, IOContext] # input_id -> IOContext
current_input_started_at: Optional[float]

_input_concurrency: Optional[int]
_semaphore: Optional[asyncio.Semaphore]
_target_concurrency: int
_max_concurrency: int
_concurrency_loop: Optional[asyncio.Task]
_input_slots: InputSlots

_environment_name: str
_heartbeat_loop: Optional[asyncio.Task]
_heartbeat_condition: asyncio.Condition
Expand All @@ -224,9 +275,18 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):
self.current_inputs = {}
self.current_input_started_at = None

self._input_concurrency = None
if container_args.function_def.pty_info.pty_type == api_pb2.PTYInfo.PTY_TYPE_SHELL:
target_concurrency = 1
max_concurrency = 0
else:
target_concurrency = container_args.function_def.allow_concurrent_inputs or 1
max_concurrency = container_args.function_def.max_concurrent_inputs or target_concurrency

self._target_concurrency = target_concurrency
self._max_concurrency = max_concurrency
self._concurrency_loop = None
self._input_slots = InputSlots(target_concurrency)

self._semaphore = None
self._environment_name = container_args.environment_name
self._heartbeat_loop = None
self._heartbeat_condition = asyncio.Condition()
Expand Down Expand Up @@ -297,7 +357,7 @@ async def _heartbeat_handle_cancellations(self) -> bool:
# Pause processing of the current input by signaling self a SIGUSR1.
input_ids_to_cancel = response.cancel_input_event.input_ids
if input_ids_to_cancel:
if self._input_concurrency > 1:
if self._target_concurrency > 1:
for input_id in input_ids_to_cancel:
if input_id in self.current_inputs:
self.current_inputs[input_id].cancel()
Expand Down Expand Up @@ -330,6 +390,39 @@ def stop_heartbeat(self):
if self._heartbeat_loop:
self._heartbeat_loop.cancel()

@asynccontextmanager
async def dynamic_concurrency_manager(self) -> AsyncGenerator[None, None]:
    """Run the dynamic concurrency polling loop for the duration of the `async with` body.

    The loop task is created inside a `TaskContext`, stored on
    `self._concurrency_loop`, and cancelled when the body exits.
    """
    async with TaskContext() as task_context:
        loop_task = task_context.create_task(self._dynamic_concurrency_loop())
        self._concurrency_loop = loop_task
        loop_task.set_name("dynamic concurrency loop")
        try:
            yield
        finally:
            # Stop polling as soon as the caller leaves the context.
            loop_task.cancel()

async def _dynamic_concurrency_loop(self):
    """Poll the server for the desired concurrency and apply it to the input slots.

    Runs until cancelled. Each iteration sends a FunctionGetDynamicConcurrency
    request carrying the configured target and max concurrency, updates
    `self._input_slots` when the returned value differs from the current one,
    and then sleeps for DYNAMIC_CONCURRENCY_INTERVAL_SECS. Failures are logged
    at debug level and the loop carries on.
    """
    logger.debug(f"Starting dynamic concurrency loop for task {self.task_id}")
    while True:
        try:
            req = api_pb2.FunctionGetDynamicConcurrencyRequest(
                function_id=self.function_id,
                target_concurrency=self._target_concurrency,
                max_concurrency=self._max_concurrency,
            )
            resp = await retry_transient_errors(
                self._client.stub.FunctionGetDynamicConcurrency,
                req,
                attempt_timeout=DYNAMIC_CONCURRENCY_TIMEOUT_SECS,
            )
            # Only touch the slots (and log) when the value actually changed.
            if self._input_slots.value != resp.concurrency:
                logger.debug(f"Dynamic concurrency set from {self._input_slots.value} to {resp.concurrency}")
                self._input_slots.set_value(resp.concurrency)
        except Exception as exc:
            # Best effort: keep the current concurrency and retry on the next tick.
            logger.debug(f"Failed to get dynamic concurrency for task {self.task_id}, {exc}")

        await asyncio.sleep(DYNAMIC_CONCURRENCY_INTERVAL_SECS)

async def get_app_objects(self) -> RunningApp:
req = api_pb2.AppGetObjectsRequest(app_id=self.app_id, include_unindexed=True)
resp = await retry_transient_errors(self._client.stub.AppGetObjects, req)
Expand Down Expand Up @@ -470,12 +563,13 @@ async def _generate_inputs(
request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
iteration = 0
while self._fetching_inputs:
await self._input_slots.acquire()

request.average_call_time = self.get_average_call_time()
request.max_values = self.get_max_inputs_to_fetch() # Deprecated; remove.
request.input_concurrency = self._input_concurrency
request.input_concurrency = self.get_input_concurrency()
request.batch_max_size, request.batch_linger_ms = batch_max_size, batch_wait_ms

await self._semaphore.acquire()
yielded = False
try:
# If number of active inputs is at max queue size, this will block.
Expand Down Expand Up @@ -508,7 +602,7 @@ async def _generate_inputs(
final_input_received = True
break

# If yielded, allow semaphore to be released via exit_context
# If yielded, allow input slots to be released via exit_context
yield inputs
yielded = True

Expand All @@ -517,35 +611,34 @@ async def _generate_inputs(
return
finally:
if not yielded:
self._semaphore.release()
self._input_slots.release()

@synchronizer.no_io_translation
async def run_inputs_outputs(
self,
finalized_functions: Dict[str, FinalizedFunction],
input_concurrency: int = 1,
batch_max_size: int = 0,
batch_wait_ms: int = 0,
) -> AsyncIterator[IOContext]:
# Ensure we do not fetch new inputs when container is too busy.
# Before trying to fetch an input, acquire the semaphore:
# - if no input is fetched, release the semaphore.
# - or, when the output for the fetched input is sent, release the semaphore.
self._input_concurrency = input_concurrency
self._semaphore = asyncio.Semaphore(input_concurrency)

async for inputs in self._generate_inputs(batch_max_size, batch_wait_ms):
io_context = await IOContext.create(self._client, finalized_functions, inputs, batch_max_size > 0)
for input_id in io_context.input_ids:
self.current_inputs[input_id] = io_context
# Before trying to fetch an input, acquire an input slot:
# - if no input is fetched, release the input slot.
# - or, when the output for the fetched input is sent, release the input slot.
dynamic_concurrency_manager = (
self.dynamic_concurrency_manager() if self._max_concurrency > self._target_concurrency else AsyncExitStack()
)
async with dynamic_concurrency_manager:
async for inputs in self._generate_inputs(batch_max_size, batch_wait_ms):
io_context = await IOContext.create(self._client, finalized_functions, inputs, batch_max_size > 0)
for input_id in io_context.input_ids:
self.current_inputs[input_id] = io_context

self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time()
yield io_context
self.current_input_id, self.current_input_started_at = (None, None)
self.current_input_id, self.current_input_started_at = io_context.input_ids[0], time.time()
yield io_context
self.current_input_id, self.current_input_started_at = (None, None)

# collect all active input slots, meaning all inputs have wrapped up.
for _ in range(input_concurrency):
await self._semaphore.acquire()
# collect all active input slots, meaning all inputs have wrapped up.
await self._input_slots.close()

@synchronizer.no_io_translation
async def _push_outputs(
Expand Down Expand Up @@ -692,7 +785,7 @@ async def exit_context(self, started_at, input_ids: List[str]):
for input_id in input_ids:
self.current_inputs.pop(input_id)

self._semaphore.release()
self._input_slots.release()

@synchronizer.no_io_translation
async def push_outputs(self, io_context: IOContext, started_at: float, data: Any, data_format: int) -> None:
Expand Down Expand Up @@ -840,6 +933,20 @@ async def interact(self, from_breakpoint: bool = False):
print("Error: Failed to start PTY shell.")
raise e

@property
def target_concurrency(self) -> int:
    """Read-only view of the target concurrency stored in `_target_concurrency`."""
    return self._target_concurrency

@property
def max_concurrency(self) -> int:
    """Read-only view of the maximum concurrency stored in `_max_concurrency`."""
    return self._max_concurrency

@classmethod
def get_input_concurrency(cls) -> int:
    """Return the current slot limit (`_input_slots.value`) of the singleton IO manager.

    Asserts that the singleton has been set.
    """
    singleton = cls._singleton
    assert singleton
    return singleton._input_slots.value

@classmethod
def stop_fetching_inputs(cls):
assert cls._singleton
Expand Down
6 changes: 6 additions & 0 deletions modal/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@ def stop_fetching_inputs():
The container will exit gracefully after the current input is processed."""

_ContainerIOManager.stop_fetching_inputs()


def get_local_input_concurrency():
    """Get the container's local input concurrency. Return 0 if the container is not running."""

    # get_input_concurrency() asserts that the IO manager singleton exists, so
    # without this guard the call would raise instead of returning the
    # documented 0 when no singleton is set.
    if not _ContainerIOManager._singleton:
        return 0
    return _ContainerIOManager.get_input_concurrency()
1 change: 1 addition & 0 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,7 @@ message Function {
uint64 batch_linger_ms = 61; // Milliseconds to block before a response is needed
bool i6pn_enabled = 62;
bool _experimental_concurrent_cancellations = 63;
uint32 max_concurrent_inputs = 64;
}

message FunctionBindParamsRequest {
Expand Down
3 changes: 3 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,9 @@ def output_lockstep(self) -> Iterator[threading.Barrier]:
yield self.put_outputs_barrier
self.put_outputs_barrier = threading.Barrier(1)

async def FunctionGetDynamicConcurrency(self, stream):
    """Test stub: always reply with a fixed concurrency of 5."""
    response = api_pb2.FunctionGetDynamicConcurrencyResponse(concurrency=5)
    await stream.send_message(response)

async def FunctionGetInputs(self, stream):
await asyncio.get_running_loop().run_in_executor(None, self.get_inputs_barrier.wait)
request: api_pb2.FunctionGetInputsRequest = await stream.recv_message()
Expand Down
Loading

0 comments on commit 194053b

Please sign in to comment.