Skip to content

Commit

Permalink
Raise exception if you put app.function on a class method (#1985)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Jul 12, 2024
1 parent ded0d09 commit dbc6cbd
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 22 deletions.
19 changes: 9 additions & 10 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,18 @@ def inner(filename):
return inner


def is_global_object(object_qual_name):
def is_global_object(object_qual_name: str):
return "<locals>" not in object_qual_name.split(".")


def is_top_level_function(f: Callable) -> bool:
"""Returns True if this function is defined in global scope.
Returns False if this function is locally scoped (including on a class).
"""
return f.__name__ == f.__qualname__


def is_async(function):
# TODO: this is somewhat hacky. We need to know whether the function is async or not in order to
# coerce the input arguments to the right type. The proper way to do is to call the function and
Expand Down Expand Up @@ -104,15 +112,6 @@ def __init__(
elif f is None and user_cls:
# "service function" for running all methods of a class
self.function_name = f"{user_cls.__name__}.*"
elif f.__qualname__ != f.__name__ and not serialized:
# single method of a class - should be only @build-methods at this point
if len(f.__qualname__.split(".")) > 2:
raise InvalidError(
f"Cannot wrap `{f.__qualname__}`:"
" functions and classes used in Modal must be defined in global scope."
" If trying to apply additional decorators, they may need to use `functools.wraps`."
)
self.function_name = f"{user_cls.__name__}.{f.__name__}"
else:
self.function_name = f.__qualname__

Expand Down
36 changes: 34 additions & 2 deletions modal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings
from io import TextIOWrapper
from pathlib import PurePosixPath
from textwrap import dedent
from typing import Any, AsyncGenerator, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple, Union

from google.protobuf.message import Message
Expand All @@ -14,7 +15,7 @@
from ._ipython import is_notebook
from ._output import OutputManager
from ._utils.async_utils import synchronize_api
from ._utils.function_utils import FunctionInfo
from ._utils.function_utils import FunctionInfo, is_global_object, is_top_level_function
from ._utils.mount_utils import validate_volumes
from .app_utils import ( # noqa: F401
_list_apps,
Expand Down Expand Up @@ -533,7 +534,9 @@ def wrapped(

# Check if the decorated object is a class
if inspect.isclass(f):
raise TypeError("The @app.function decorator cannot be used on a class. Please use @app.cls instead.")
raise TypeError(
"The `@app.function` decorator cannot be used on a class. Please use `@app.cls` instead."
)

if isinstance(f, _PartialFunction):
# typically for @function-wrapped @web_endpoint and @asgi_app
Expand All @@ -547,6 +550,35 @@ def wrapped(
if webhook_config and interactive:
raise InvalidError("interactive=True is not supported with web endpoint functions")
else:
if not is_global_object(f.__qualname__) and not serialized:
raise InvalidError(
dedent(
"""
The `@app.function` decorator must apply to functions in global scope,
unless `serialize=True` is set.
If trying to apply additional decorators, they may need to use `functools.wraps`.
"""
)
)

if not is_top_level_function(f) and is_global_object(f.__qualname__):
raise InvalidError(
dedent(
"""
The `@app.function` decorator cannot be used on class methods.
Please use `@app.cls` with `@modal.method` instead. Example:
```python
@app.cls()
class MyClass:
@modal.method()
def f(self, x):
...
```
"""
)
)

info = FunctionInfo(f, serialized=serialized, name_override=name)
webhook_config = None
raw_f = f
Expand Down
12 changes: 12 additions & 0 deletions test/function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,3 +790,15 @@ def test_warn_on_local_volume_mount(client, servicer):
assert modal.is_local()
with pytest.warns(match="local"):
dummy_function.local()


class X:
def f(self):
...


def test_function_decorator_on_method():
app = modal.App()

with pytest.raises(InvalidError, match="@app.cls"):
app.function()(X.f)
10 changes: 0 additions & 10 deletions test/function_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Copyright Modal Labs 2023
import pytest

from modal import method, web_endpoint
from modal._utils.function_utils import FunctionInfo, method_has_params
from modal.exception import InvalidError


def hasarg(a):
Expand Down Expand Up @@ -49,14 +47,6 @@ def test_method_has_params():
assert method_has_params(Cls().buz)


def test_nonglobal_function():
def f():
...

with pytest.raises(InvalidError, match=r"Cannot wrap `test_nonglobal_function.<locals>.f"):
FunctionInfo(f)


class Foo:
def __init__(self):
pass
Expand Down

0 comments on commit dbc6cbd

Please sign in to comment.