Skip to content

Commit

Permalink
update thread pool
Browse files Browse the repository at this point in the history
  • Loading branch information
cathyzbn committed Aug 27, 2024
1 parent 986e5aa commit 38b4ac5
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,14 @@
webhook_asgi_app,
wsgi_app_wrapper,
)
from ._container_io_manager import ContainerIOManager, FinalizedFunction, IOContext, UserException, _ContainerIOManager
from ._container_io_manager import (
ConcurrencyManager,
ContainerIOManager,
FinalizedFunction,
IOContext,
UserException,
_ContainerIOManager,
)
from ._proxy_tunnel import proxy_tunnel
from ._serialization import deserialize, deserialize_proto_params
from ._utils.async_utils import TaskContext, synchronizer
Expand Down Expand Up @@ -211,6 +218,9 @@ class DaemonizedThreadPool:
# Used instead of ThreadPoolExecutor, since the latter won't allow
# the interpreter to shut down before the currently running tasks
# have finished
def __init__(self, concurrency_manager: ConcurrencyManager):
self.concurrency_manager = concurrency_manager

def __enter__(self):
self.spawned_workers = 0
self.inputs: queue.Queue[Any] = queue.Queue()
Expand Down Expand Up @@ -243,8 +253,9 @@ def worker_thread():
logger.exception(f"Exception raised by {_func} in DaemonizedThreadPool worker!")
self.inputs.task_done()

threading.Thread(target=worker_thread, daemon=True).start()
self.spawned_workers += 1
if self.spawned_workers < self.concurrency_manager.get_concurrency():
threading.Thread(target=worker_thread, daemon=True).start()
self.spawned_workers += 1

self.inputs.put((func, args))

Expand Down Expand Up @@ -413,7 +424,7 @@ def run_input_sync(io_context: IOContext) -> None:
reset_context()

if target_input_concurrency > 1:
with DaemonizedThreadPool() as thread_pool:
with DaemonizedThreadPool(container_io_manager._concurrency_manager) as thread_pool:

def make_async_cancel_callback(task):
def f():
Expand Down Expand Up @@ -730,7 +741,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
batch_wait_ms = 0
else:
target_concurrency = function_def.allow_concurrent_inputs or 1
max_concurrency = 0 # TODO(cathy) add this with interface
max_concurrency = 0
batch_max_size = function_def.batch_max_size or 0
batch_wait_ms = function_def.batch_linger_ms or 0

Expand Down

0 comments on commit 38b4ac5

Please sign in to comment.