From dbc6cbda6bbcd041ea4bafa71d0e31997d4e0bf7 Mon Sep 17 00:00:00 2001 From: Erik Bernhardsson Date: Fri, 12 Jul 2024 12:37:57 -0400 Subject: [PATCH] Raise exception if you put app.function on a class method (#1985) --- modal/_utils/function_utils.py | 19 +++++++++--------- modal/app.py | 36 ++++++++++++++++++++++++++++++++-- test/function_test.py | 12 ++++++++++++ test/function_utils_test.py | 10 ---------- 4 files changed, 55 insertions(+), 22 deletions(-) diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index 87f7d8432..8752e9a40 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -47,10 +47,18 @@ def inner(filename): return inner -def is_global_object(object_qual_name): +def is_global_object(object_qual_name: str): return "" not in object_qual_name.split(".") +def is_top_level_function(f: Callable) -> bool: + """Returns True if this function is defined in global scope. + + Returns False if this function is locally scoped (including on a class). + """ + return f.__name__ == f.__qualname__ + + def is_async(function): # TODO: this is somewhat hacky. We need to know whether the function is async or not in order to # coerce the input arguments to the right type. The proper way to do is to call the function and @@ -104,15 +112,6 @@ def __init__( elif f is None and user_cls: # "service function" for running all methods of a class self.function_name = f"{user_cls.__name__}.*" - elif f.__qualname__ != f.__name__ and not serialized: - # single method of a class - should be only @build-methods at this point - if len(f.__qualname__.split(".")) > 2: - raise InvalidError( - f"Cannot wrap `{f.__qualname__}`:" - " functions and classes used in Modal must be defined in global scope." - " If trying to apply additional decorators, they may need to use `functools.wraps`." - ) - self.function_name = f"{user_cls.__name__}.{f.__name__}" else: self.function_name = f.__qualname__ diff --git a/modal/app.py b/modal/app.py index d304a8a03..359076e12 100644 --- a/modal/app.py +++ b/modal/app.py @@ -4,6 +4,7 @@ import warnings from io import TextIOWrapper from pathlib import PurePosixPath +from textwrap import dedent from typing import Any, AsyncGenerator, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple, Union from google.protobuf.message import Message @@ -14,7 +15,7 @@ from ._ipython import is_notebook from ._output import OutputManager from ._utils.async_utils import synchronize_api -from ._utils.function_utils import FunctionInfo +from ._utils.function_utils import FunctionInfo, is_global_object, is_top_level_function from ._utils.mount_utils import validate_volumes from .app_utils import ( # noqa: F401 _list_apps, @@ -533,7 +534,9 @@ def wrapped( # Check if the decorated object is a class if inspect.isclass(f): - raise TypeError("The @app.function decorator cannot be used on a class. Please use @app.cls instead.") + raise TypeError( + "The `@app.function` decorator cannot be used on a class. Please use `@app.cls` instead." + ) if isinstance(f, _PartialFunction): # typically for @function-wrapped @web_endpoint and @asgi_app @@ -547,6 +550,35 @@ def wrapped( if webhook_config and interactive: raise InvalidError("interactive=True is not supported with web endpoint functions") else: + if not is_global_object(f.__qualname__) and not serialized: + raise InvalidError( + dedent( + """ + The `@app.function` decorator must apply to functions in global scope, + unless `serialize=True` is set. + If trying to apply additional decorators, they may need to use `functools.wraps`. + """ + ) + ) + + if not is_top_level_function(f) and is_global_object(f.__qualname__): + raise InvalidError( + dedent( + """ + The `@app.function` decorator cannot be used on class methods. + Please use `@app.cls` with `@modal.method` instead. Example: + + ```python + @app.cls() + class MyClass: + @modal.method() + def f(self, x): + ... + ``` + """ + ) + ) + info = FunctionInfo(f, serialized=serialized, name_override=name) webhook_config = None raw_f = f diff --git a/test/function_test.py b/test/function_test.py index 9d202ab61..7e5ff8407 100644 --- a/test/function_test.py +++ b/test/function_test.py @@ -790,3 +790,15 @@ def test_warn_on_local_volume_mount(client, servicer): assert modal.is_local() with pytest.warns(match="local"): dummy_function.local() + + +class X: + def f(self): + ... + + +def test_function_decorator_on_method(): + app = modal.App() + + with pytest.raises(InvalidError, match="@app.cls"): + app.function()(X.f) diff --git a/test/function_utils_test.py b/test/function_utils_test.py index 2c9b0a338..0d867866e 100644 --- a/test/function_utils_test.py +++ b/test/function_utils_test.py @@ -1,9 +1,7 @@ # Copyright Modal Labs 2023 -import pytest from modal import method, web_endpoint from modal._utils.function_utils import FunctionInfo, method_has_params -from modal.exception import InvalidError def hasarg(a): @@ -49,14 +47,6 @@ def test_method_has_params(): assert method_has_params(Cls().buz) -def test_nonglobal_function(): - def f(): - ... - - with pytest.raises(InvalidError, match=r"Cannot wrap `test_nonglobal_function..f"): - FunctionInfo(f) - - class Foo: def __init__(self): pass