Skip to content

Commit

Permalink
reduce duplicate code
Browse files (browse the repository at this point in the history)
  • Loading branch information
cathyzbn committed Aug 1, 2024
1 parent 853c37b commit d2a0f6b
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 110 deletions.
71 changes: 45 additions & 26 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import typing
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple

from google.protobuf.message import Message
from synchronicity import Interface
Expand Down Expand Up @@ -324,15 +324,18 @@ def _sigint_handler():


def _aggregate_args_and_kwargs(
local_inputs: Union[LocalInput, List[LocalInput]],
local_inputs: List[LocalInput],
callable: Callable[..., Any],
) -> Tuple[Union[str, List[str]], Union[str, List[str]], Tuple[Any, ...], Dict[str, Any]]:
if isinstance(local_inputs, LocalInput):
return local_inputs.input_id, local_inputs.function_call_id, local_inputs.args, local_inputs.kwargs

# aggregate args and kwargs for batched input
is_batched: bool,
) -> Tuple[List[str], List[str], Tuple[Any, ...], Dict[str, Any]]:
input_ids = [local_input.input_id for local_input in local_inputs]
function_call_ids = [local_input.function_call_id for local_input in local_inputs]

if not is_batched:
assert len(local_inputs) == 1
return input_ids, function_call_ids, local_inputs[0].args, local_inputs[0].kwargs

# aggregate args and kwargs for batched input
param_names = list(inspect.signature(callable).parameters.keys())
for param in inspect.signature(callable).parameters.values():
if param.default is not inspect.Parameter.empty:
Expand Down Expand Up @@ -375,16 +378,17 @@ def call_function(
):
async def run_input_async(
finalized_function: FinalizedFunction,
local_inputs: Union[LocalInput, List[LocalInput]],
local_inputs: List[LocalInput],
container_io_manager: "modal._container_io_manager.ContainerIOManager",
) -> None:
started_at = time.time()
is_batched = container_io_manager.is_batched()
input_ids, function_call_ids, args, kwargs = _aggregate_args_and_kwargs(
local_inputs, finalized_function.callable
local_inputs, finalized_function.callable, is_batched
)
reset_context = _set_current_context_ids(input_ids, function_call_ids)
async with container_io_manager.handle_input_exception.aio(
input_ids, started_at, finalized_function.callable.__name__
input_ids, started_at, finalized_function.callable.__name__, is_batched
):
logger.debug(f"Starting input {input_ids} (async)")
res = finalized_function.callable(*args, **kwargs)
Expand All @@ -394,16 +398,17 @@ async def run_input_async(
if finalized_function.is_generator:
if not inspect.isasyncgen(res):
raise InvalidError(f"Async generator function returned value of type {type(res)}")
if isinstance(input_ids, list) or isinstance(function_call_ids, list):
if is_batched:
raise InvalidError(
f"Batch function {finalized_function.callable.__name__} cannot return generators."
)
assert len(function_call_ids) == 1

# Send up to this many outputs at a time.
generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024)
generator_output_task = asyncio.create_task(
container_io_manager.generator_output_task.aio(
function_call_ids,
function_call_ids[0],
finalized_function.data_format,
generator_queue,
)
Expand All @@ -419,6 +424,7 @@ async def run_input_async(
message = api_pb2.GeneratorDone(items_total=item_count)
await container_io_manager.push_output.aio(
input_ids,
is_batched,
started_at,
finalized_function.callable.__name__,
message,
Expand All @@ -432,21 +438,29 @@ async def run_input_async(
)
value = await res
await container_io_manager.push_output.aio(
input_ids, started_at, finalized_function.callable.__name__, value, finalized_function.data_format
input_ids,
is_batched,
started_at,
finalized_function.callable.__name__,
value,
finalized_function.data_format,
)
reset_context()

def run_input_sync(
finalized_function: FinalizedFunction,
local_inputs: Union[LocalInput, List[LocalInput]],
local_inputs: List[LocalInput],
container_io_manager: "modal._container_io_manager.ContainerIOManager",
) -> None:
started_at = time.time()
is_batched = container_io_manager.is_batched()
input_ids, function_call_ids, args, kwargs = _aggregate_args_and_kwargs(
local_inputs, finalized_function.callable
local_inputs, finalized_function.callable, is_batched
)
reset_context = _set_current_context_ids(input_ids, function_call_ids)
with container_io_manager.handle_input_exception(input_ids, started_at, finalized_function.callable.__name__):
with container_io_manager.handle_input_exception(
input_ids, started_at, finalized_function.callable.__name__, is_batched
):
logger.debug(f"Starting input {input_ids} (sync)")
res = finalized_function.callable(*args, **kwargs)
logger.debug(f"Finished input {input_ids} (sync)")
Expand All @@ -455,15 +469,16 @@ def run_input_sync(
if finalized_function.is_generator:
if not inspect.isgenerator(res):
raise InvalidError(f"Generator function returned value of type {type(res)}")
if isinstance(function_call_ids, list):
if is_batched:
raise InvalidError(
f"Batch function {finalized_function.callable.__name__} cannot return generators."
)
assert len(function_call_ids) == 1

# Send up to this many outputs at a time.
generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024)
generator_output_task: concurrent.futures.Future = container_io_manager.generator_output_task( # type: ignore
function_call_ids,
function_call_ids[0],
finalized_function.data_format,
generator_queue,
_future=True, # type: ignore # Synchronicity magic to return a future.
Expand All @@ -479,6 +494,7 @@ def run_input_sync(
message = api_pb2.GeneratorDone(items_total=item_count)
container_io_manager.push_output(
input_ids,
is_batched,
started_at,
finalized_function.callable.__name__,
message,
Expand All @@ -491,17 +507,20 @@ def run_input_sync(
" You might need to use @app.function(..., is_generator=True)."
)
container_io_manager.push_output(
input_ids, started_at, finalized_function.callable.__name__, res, finalized_function.data_format
input_ids,
is_batched,
started_at,
finalized_function.callable.__name__,
res,
finalized_function.data_format,
)
reset_context()

def _get_finalized_functions(local_inputs: Union[LocalInput, List[LocalInput]]) -> FinalizedFunction:
if isinstance(local_inputs, list):
assert len(local_inputs) > 0
assert len(set([local_input.method_name for local_input in local_inputs])) == 1
return finalized_functions[local_inputs[0].method_name]
else:
return finalized_functions[local_inputs.method_name]
def _get_finalized_functions(local_inputs: List[LocalInput]) -> FinalizedFunction:
assert len(local_inputs) > 0
# all functions in a batch must have the same method name
assert len(set([local_input.method_name for local_input in local_inputs])) == 1
return finalized_functions[local_inputs[0].method_name]

if input_concurrency > 1:
with DaemonizedThreadPool(max_threads=input_concurrency) as thread_pool:
Expand Down
Loading

0 comments on commit d2a0f6b

Please sign in to comment.