From ae405eca03d8c5384ec15750fd264f52697d03d2 Mon Sep 17 00:00:00 2001 From: Aric Coady Date: Thu, 30 Mar 2023 17:20:40 -0700 Subject: [PATCH] Support forward references in `overload`. Refs #89 #issuecomment-1445842092. --- README.md | 2 +- multimethod/__init__.py | 22 ++++++++++++++++------ tests/test_overload.py | 10 ++++++++++ 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 4d4de36..5f7bfdd 100644 --- a/README.md +++ b/README.md @@ -166,7 +166,7 @@ dev * Python >=3.8 required * `Type[...]` dispatches on class arguments * `|` syntax for union types -* `overload` supports generics +* `overload` supports generics and forward references 1.9.1 diff --git a/multimethod/__init__.py b/multimethod/__init__.py index eb077c7..ff341de 100644 --- a/multimethod/__init__.py +++ b/multimethod/__init__.py @@ -325,11 +325,7 @@ def __call__(self, *args, **kwargs): raise DispatchError(f"Function {func.__code__}") from ex def evaluate(self): - """Evaluate any pending forward references. - - This can be called explicitly when using forward references, - otherwise cache misses will evaluate. - """ + """Evaluate any pending forward references.""" while self.pending: func = self.pending.pop() self[get_types(func)] = func @@ -388,21 +384,35 @@ def isa(*types: type) -> Callable: class overload(dict): """Ordered functions which dispatch based on their annotated predicates.""" + pending: set __get__ = multimethod.__get__ def __new__(cls, func): namespace = inspect.currentframe().f_back.f_locals self = functools.update_wrapper(super().__new__(cls), func) + self.pending = set() return namespace.get(func.__name__, self) def __init__(self, func: Callable): + try: + sig = self.signature(func) + except (NameError, AttributeError): + self.pending.add(func) + else: + self[sig] = func + + @classmethod + def signature(cls, func: Callable) -> inspect.Signature: for name, value in get_type_hints(func).items(): if not callable(value) or isinstance(value, type) or hasattr(value, '__origin__'): func.__annotations__[name] = isa(value) - self[inspect.signature(func)] = func + return inspect.signature(func) def __call__(self, *args, **kwargs): """Dispatch to first matching function.""" + while self.pending: + func = self.pending.pop() + self[self.signature(func)] = func for sig in reversed(self): try: arguments = sig.bind(*args, **kwargs).arguments diff --git a/tests/test_overload.py b/tests/test_overload.py index 0f86581..9985b1d 100644 --- a/tests/test_overload.py +++ b/tests/test_overload.py @@ -54,3 +54,13 @@ def func(arg: Optional[str]): assert func(None) is None with pytest.raises(DispatchError): func(0) + + +class cls: + @overload + def method(self: 'cls'): + return type(self) + + +def test_annotations(): + assert cls().method() is cls