Skip to content

Commit

Permalink
fix error message
Browse files Browse the repository at this point in the history
  • Loading branch information
cathyzbn committed Aug 27, 2024
1 parent 3a72933 commit 8b74df9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 19 deletions.
18 changes: 14 additions & 4 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ class DaemonizedThreadPool:
# Used instead of ThreadPoolExecutor, since the latter won't allow
# the interpreter to shut down before the currently running tasks
# have finished
def __init__(self, container_io_manager: "modal._container_io_manager.ContainerIOManager"):
    """Bind this thread pool to the container's IO manager.

    Stored so the pool can consult the container's current input-concurrency
    settings when deciding whether to spawn additional worker threads.
    """
    # NOTE(review): the worker-spawn check below reads `self.concurrency_manager`,
    # while this stores `self.container_io_manager` — confirm the attribute
    # naming is consistent (looks like a possible mismatch).
    self.container_io_manager = container_io_manager

def __enter__(self):
self.spawned_workers = 0
self.inputs: queue.Queue[Any] = queue.Queue()
Expand Down Expand Up @@ -243,8 +246,9 @@ def worker_thread():
logger.exception(f"Exception raised by {_func} in DaemonizedThreadPool worker!")
self.inputs.task_done()

threading.Thread(target=worker_thread, daemon=True).start()
self.spawned_workers += 1
if self.spawned_workers < self.concurrency_manager.get_input_concurrency():
threading.Thread(target=worker_thread, daemon=True).start()
self.spawned_workers += 1

self.inputs.put((func, args))

Expand Down Expand Up @@ -413,7 +417,7 @@ def run_input_sync(io_context: IOContext) -> None:
reset_context()

if target_input_concurrency > 1:
with DaemonizedThreadPool() as thread_pool:
with DaemonizedThreadPool(container_io_manager) as thread_pool:

def make_async_cancel_callback(task):
def f():
Expand Down Expand Up @@ -730,7 +734,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
batch_wait_ms = 0
else:
target_concurrency = function_def.allow_concurrent_inputs or 1
max_concurrency = function_def.max_concurrent_inputs or 0
max_concurrency = 0 # TODO(cathy) add this with interface
batch_max_size = function_def.batch_max_size or 0
batch_wait_ms = function_def.batch_linger_ms or 0

Expand All @@ -747,6 +751,12 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):

# Initialize the function, importing user code.
with container_io_manager.handle_user_exception():
if max_concurrency != 0 and max_concurrency <= target_concurrency:
raise InvalidError("max_concurrent_inputs must be greater than or equal to allow_concurrent_inputs.")
if max_concurrency != 0 and target_concurrency <= 1:
raise InvalidError(
"allow_concurrent_inputs must be greater than 1 to enable automatic input concurrency scaling."
)
if container_args.serialized_params:
param_args, param_kwargs = deserialize_params(container_args.serialized_params, function_def, _client)
else:
Expand Down
21 changes: 6 additions & 15 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .running_app import RunningApp

CONCURRENCY_STATUS_INTERVAL = 3 # seconds
CONCURRENCY_STATUS_TIMEOUT = 5 # seconds
CONCURRENCY_STATUS_TIMEOUT = 10 # seconds
MAX_OUTPUT_BATCH_SIZE: int = 49
RTT_S: float = 0.5 # conservative estimate of RTT in seconds.

Expand Down Expand Up @@ -182,8 +182,8 @@ def validate_output_data(self, data: Any) -> List[Any]:
class ConcurrencyManager(asyncio.Semaphore):
"""Manages the concurrency of inputs for a running container.
The class allows dynamically adjusting the concurrency by adjusting the semaphore value.
It eagerly increases concurrency and lazily decreases concurrency with `_owed_releases`.
The class allows dynamically adjusting the concurrency by changing the semaphore value.
It eagerly increases and lazily decreases concurrency with `_owed_releases`.
"""

_target_concurrency: int
Expand All @@ -206,14 +206,13 @@ def __init__(self, target_concurrency: int, max_concurrency: int) -> None:
self._monitor_thread = None

def initialize(self) -> None:
# Initialize the semaphore with the number of concurrent inputs
# Initialize the semaphore with the current concurrency
assert self._concurrency and not self._initialized
super().__init__(self._concurrency)
self._initialized = True

# Spin up thread to dynamically adjust concurrency given user specified target and max concurrency
# Spin up thread to dynamically adjust concurrency if user specified target and max concurrency
if self._max_concurrency > 1:
print("thread launched!")
self._monitor_thread = threading.Thread(target=self.monitor_concurrency, daemon=True)
self._monitor_thread.start()

Expand Down Expand Up @@ -323,13 +322,6 @@ class _ContainerIOManager:
def _init(
self, container_args: api_pb2.ContainerArguments, client: _Client, target_concurrency: int, max_concurrency: int
) -> None:
with self.handle_user_exception():
if max_concurrency != 0 and max_concurrency <= target_concurrency:
raise InvalidError("max_concurrent_inputs must be greater than or equal to allow_concurrent_inputs.")
if max_concurrency != 0 and target_concurrency <= 1:
raise InvalidError(
"allow_concurrent_inputs must be greater than 1 to enable automatic input concurrency scaling."
)
self.task_id = container_args.task_id
self.function_id = container_args.function_id
self.app_id = container_args.app_id
Expand Down Expand Up @@ -650,8 +642,7 @@ async def run_inputs_outputs(
# Before trying to fetch an input, acquire the semaphore:
# - if no input is fetched, release the semaphore.
# - or, when the output for the fetched input is sent, release the semaphore.
with self.handle_user_exception():
self._concurrency_manager.initialize()
self._concurrency_manager.initialize()

async for inputs in self._generate_inputs(batch_max_size, batch_wait_ms):
io_context = await IOContext.create(self._client, finalized_functions, inputs, batch_max_size > 0)
Expand Down

0 comments on commit 8b74df9

Please sign in to comment.