Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batching #2051

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion modal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .image import Image
from .mount import Mount
from .network_file_system import NetworkFileSystem
from .partial_function import asgi_app, build, enter, exit, method, web_endpoint, web_server, wsgi_app
from .partial_function import asgi_app, batch, build, enter, exit, method, web_endpoint, web_server, wsgi_app
from .proxy import Proxy
from .queue import Queue
from .retries import Retries
Expand Down Expand Up @@ -64,6 +64,7 @@
"Tunnel",
"Volume",
"asgi_app",
"batch",
"build",
"current_function_call_id",
"current_input_id",
Expand Down
32 changes: 30 additions & 2 deletions modal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
from .mount import _Mount
from .network_file_system import _NetworkFileSystem
from .object import _Object
from .partial_function import _find_callables_for_cls, _PartialFunction, _PartialFunctionFlags
from .partial_function import (
_find_callables_for_cls,
_find_partial_methods_for_user_cls,
_PartialFunction,
_PartialFunctionFlags,
)
from .proxy import _Proxy
from .retries import Retries
from .runner import _run_app
Expand Down Expand Up @@ -589,13 +594,15 @@ def wrapped(
)

if isinstance(f, _PartialFunction):
# typically for @function-wrapped @web_endpoint and @asgi_app
# typically for @function-wrapped @web_endpoint, @asgi_app, or @batch
f.wrapped = True
info = FunctionInfo(f.raw_f, serialized=serialized, name_override=name)
raw_f = f.raw_f
webhook_config = f.webhook_config
is_generator = f.is_generator
keep_warm = f.keep_warm or keep_warm
batch_max_size = f.batch_max_size
batch_linger_ms = f.batch_linger_ms

if webhook_config and interactive:
raise InvalidError("interactive=True is not supported with web endpoint functions")
Expand Down Expand Up @@ -631,6 +638,8 @@ def f(self, x):

info = FunctionInfo(f, serialized=serialized, name_override=name)
webhook_config = None
batch_max_size = None
batch_linger_ms = None
raw_f = f

if info.function_name.endswith(".app"):
Expand Down Expand Up @@ -668,6 +677,8 @@ def f(self, x):
retries=retries,
concurrency_limit=concurrency_limit,
allow_concurrent_inputs=allow_concurrent_inputs,
batch_max_size=batch_max_size,
batch_linger_ms=batch_linger_ms,
container_idle_timeout=container_idle_timeout,
timeout=timeout,
keep_warm=keep_warm,
Expand Down Expand Up @@ -768,6 +779,21 @@ def wrapper(user_cls: CLS_T) -> _Cls:
raise InvalidError("`region` and `_experimental_scheduler_placement` cannot be used together")
scheduler_placement = SchedulerPlacement(region=region)

batch_functions = _find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.BATCH)
if batch_functions:
if len(batch_functions) > 1:
raise InvalidError(f"Modal class {user_cls.__name__} can only have one batch function.")
if len(_find_partial_methods_for_user_cls(user_cls, _PartialFunctionFlags.FUNCTION)) > 1:
raise InvalidError(
f"Modal class {user_cls.__name__} with a modal batch function cannot have other modal methods."
)
batch_function = next(iter(batch_functions.values()))
batch_max_size = batch_function.batch_max_size
batch_linger_ms = batch_function.batch_linger_ms
else:
batch_max_size = None
batch_linger_ms = None

cls_func = _Function.from_args(
info,
app=self,
Expand All @@ -785,6 +811,8 @@ def wrapper(user_cls: CLS_T) -> _Cls:
retries=retries,
concurrency_limit=concurrency_limit,
allow_concurrent_inputs=allow_concurrent_inputs,
batch_max_size=batch_max_size,
batch_linger_ms=batch_linger_ms,
container_idle_timeout=container_idle_timeout,
timeout=timeout,
cpu=cpu,
Expand Down
8 changes: 8 additions & 0 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ async def _load(method_bound_function: "_Function", resolver: Resolver, existing
is_method=True,
use_function_id=class_service_function.object_id,
use_method_name=method_name,
batch_max_size=partial_function.batch_max_size or 0,
batch_linger_ms=partial_function.batch_linger_ms or 0,
)
assert resolver.app_id
request = api_pb2.FunctionCreateRequest(
Expand Down Expand Up @@ -484,6 +486,8 @@ def from_args(
timeout: Optional[int] = None,
concurrency_limit: Optional[int] = None,
allow_concurrent_inputs: Optional[int] = None,
batch_max_size: Optional[int] = None,
batch_linger_ms: Optional[int] = None,
container_idle_timeout: Optional[int] = None,
cpu: Optional[float] = None,
keep_warm: Optional[int] = None, # keep_warm=True is equivalent to keep_warm=1
Expand Down Expand Up @@ -652,6 +656,8 @@ def from_args(
)
else:
raise InvalidError("Webhooks cannot be generators")
if is_generator and batch_max_size:
raise InvalidError("Batch functions cannot return generators")

if container_idle_timeout is not None and container_idle_timeout <= 0:
raise InvalidError("`container_idle_timeout` must be > 0")
Expand Down Expand Up @@ -808,6 +814,8 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona
app_name=app_name,
is_builder_function=is_builder_function,
allow_concurrent_inputs=allow_concurrent_inputs or 0,
batch_max_size=batch_max_size or 0,
batch_linger_ms=batch_linger_ms or 0,
worker_id=config.get("worker_id"),
is_auto_snapshot=is_auto_snapshot,
is_method=bool(info.user_cls) and not info.is_service_class(),
Expand Down
57 changes: 55 additions & 2 deletions modal/partial_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,33 @@
from .exception import InvalidError, deprecation_error, deprecation_warning
from .functions import _Function

MAX_BATCH_SIZE = 1000
MAX_BATCH_LINGER_MS = 10 * 60 * 1000 # 10 minutes


class _PartialFunctionFlags(enum.IntFlag):
    """Bit flags stored in the `flags` attribute of a `_PartialFunction`.

    Each member occupies its own bit, so several flags can be combined with `|`
    (for example `FUNCTION | BATCH`).
    """

    FUNCTION: int = 1 << 0
    BUILD: int = 1 << 1
    ENTER_PRE_SNAPSHOT: int = 1 << 2
    ENTER_POST_SNAPSHOT: int = 1 << 3
    EXIT: int = 1 << 4
    BATCH: int = 1 << 5

    @staticmethod
    def all() -> "_PartialFunctionFlags":
        """Return the bitwise complement of the empty flag value, i.e. a mask matching every flag."""
        no_flags = _PartialFunctionFlags(0)
        return ~no_flags  # type: ignore # for some reason mypy thinks this has type int


class _PartialFunction:
"""Intermediate function, produced by @method or @web_endpoint"""
"""Intermediate function, produced by @method, @web_endpoint, or @batch"""

raw_f: Callable[..., Any]
flags: _PartialFunctionFlags
webhook_config: Optional[api_pb2.WebhookConfig]
is_generator: Optional[bool]
keep_warm: Optional[int]
batch_max_size: Optional[int]
batch_linger_ms: Optional[int]

def __init__(
self,
Expand All @@ -49,13 +55,17 @@ def __init__(
webhook_config: Optional[api_pb2.WebhookConfig] = None,
is_generator: Optional[bool] = None,
keep_warm: Optional[int] = None,
batch_max_size: Optional[int] = None,
batch_linger_ms: Optional[int] = None,
):
self.raw_f = raw_f
self.flags = flags
self.webhook_config = webhook_config
self.is_generator = is_generator
self.keep_warm = keep_warm
self.wrapped = False # Make sure that this was converted into a FunctionHandle
self.batch_max_size = batch_max_size
self.batch_linger_ms = batch_linger_ms

def __get__(self, obj, objtype=None) -> _Function:
k = self.raw_f.__name__
Expand Down Expand Up @@ -93,6 +103,8 @@ def add_flags(self, flags) -> "_PartialFunction":
flags=(self.flags | flags),
webhook_config=self.webhook_config,
keep_warm=self.keep_warm,
batch_max_size=self.batch_max_size,
batch_linger_ms=self.batch_linger_ms,
)


Expand Down Expand Up @@ -190,12 +202,18 @@ def f(self):

def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
nonlocal is_generator
if isinstance(raw_f, _PartialFunction) and raw_f.webhook_config:
if isinstance(raw_f, _PartialFunction) and (raw_f.webhook_config):
raw_f.wrapped = True # suppress later warning
raise InvalidError(
"Web endpoints on classes should not be wrapped by `@method`. "
"Suggestion: remove the `@method` decorator."
)
if isinstance(raw_f, _PartialFunction) and (raw_f.batch_max_size is not None):
raw_f.wrapped = True # suppress later warning
raise InvalidError(
"Batch function on classes should not be wrapped by `@method`. "
"Suggestion: remove the `@method` decorator."
)
if is_generator is None:
is_generator = inspect.isgeneratorfunction(raw_f) or inspect.isasyncgenfunction(raw_f)
return _PartialFunction(raw_f, _PartialFunctionFlags.FUNCTION, is_generator=is_generator, keep_warm=keep_warm)
Expand Down Expand Up @@ -554,6 +572,40 @@ def wrapper(f: ExitHandlerType) -> _PartialFunction:
return wrapper


def _batch(
    _warn_parentheses_missing=None,
    *,
    batch_max_size: int,
    batch_linger_ms: int,
) -> Callable[[Callable[..., Any]], _PartialFunction]:
    """Decorator factory that marks a function as a batch function.

    Both settings are validated when the factory is called, before any function
    is decorated. The returned decorator wraps the target in a `_PartialFunction`
    flagged `FUNCTION | BATCH` that carries the two settings.

    Args:
        _warn_parentheses_missing: Catches accidental positional use (e.g. the
            decorator applied without parentheses); any truthy value is rejected.
        batch_max_size: Must satisfy `1 <= batch_max_size < MAX_BATCH_SIZE`.
        batch_linger_ms: Must satisfy `0 <= batch_linger_ms < MAX_BATCH_LINGER_MS`.

    Raises:
        InvalidError: If a positional argument is given, if either setting is out
            of range, or (from the returned decorator) if it is applied to an
            already-built `_Function`.
    """
    if _warn_parentheses_missing:
        raise InvalidError("Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@batch()`.")

    # Range checks run in a fixed order: the first violated bound decides the message.
    if batch_max_size < 1:
        raise InvalidError("batch_max_size must be a positive integer.")
    elif batch_max_size >= MAX_BATCH_SIZE:
        raise InvalidError(f"batch_max_size must be less than {MAX_BATCH_SIZE}.")

    if batch_linger_ms < 0:
        raise InvalidError("batch_linger_ms must be a non-negative integer.")
    elif batch_linger_ms >= MAX_BATCH_LINGER_MS:
        raise InvalidError(f"batch_linger_ms must be less than {MAX_BATCH_LINGER_MS}.")

    def wrapper(raw_f: Union[Callable[[Any], Any], _PartialFunction]) -> _PartialFunction:
        if not isinstance(raw_f, _Function):
            batch_flags = _PartialFunctionFlags.FUNCTION | _PartialFunctionFlags.BATCH
            return _PartialFunction(
                raw_f,
                batch_flags,
                batch_max_size=batch_max_size,
                batch_linger_ms=batch_linger_ms,
            )

        # The batch decorator was stacked on top of an existing `_Function`:
        # report the underlying callable and show the expected decorator order.
        underlying = raw_f.get_raw_f()
        raise InvalidError(
            f"Applying decorators for {underlying} in the wrong order!\nUsage:\n\n"
            "@app.function()\n@modal.batch()\ndef batched_function():\n ..."
        )

    return wrapper


method = synchronize_api(_method)
web_endpoint = synchronize_api(_web_endpoint)
asgi_app = synchronize_api(_asgi_app)
Expand All @@ -562,3 +614,4 @@ def wrapper(f: ExitHandlerType) -> _PartialFunction:
build = synchronize_api(_build)
enter = synchronize_api(_enter)
exit = synchronize_api(_exit)
batch = synchronize_api(_batch)
31 changes: 31 additions & 0 deletions test/cls_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,3 +867,34 @@ def test_disabled_parameterized_snap_cls():
app.cls(enable_memory_snapshot=True)(ParameterizedClass2)

app.cls(enable_memory_snapshot=True)(ParameterizedClass3)


app_batch = App()


def test_batch_method_duplicate_error(client):
    """Decorating a class with `app_batch.cls` must reject invalid batch-method combinations.

    Two cases are checked: a batch method next to a regular `@modal.method`,
    and two batch methods on the same class.
    """
    mixed_methods_error = "Modal class BatchClass_1 with a modal batch function cannot have other modal methods."
    two_batches_error = "Modal class BatchClass_2 can only have one batch function."

    # Case 1: one regular method plus one batch method.
    with pytest.raises(InvalidError, match=mixed_methods_error):

        @app_batch.cls(serialized=True)
        class BatchClass_1:
            @modal.method()
            def method(self):
                pass

            @modal.batch(batch_max_size=2, batch_linger_ms=0)
            def batch_method(self):
                pass

    # Case 2: two batch methods on one class.
    with pytest.raises(InvalidError, match=two_batches_error):

        @app_batch.cls(serialized=True)
        class BatchClass_2:
            @modal.batch(batch_max_size=2, batch_linger_ms=0)
            def batch_method_1(self):
                pass

            @modal.batch(batch_max_size=2, batch_linger_ms=0)
            def batch_method_2(self):
                pass
Loading
Loading