Skip to content

Commit

Permalink
Support forward references in overload.
Browse files Browse the repository at this point in the history
Refs #89 #issuecomment-1445842092.
  • Loading branch information
coady committed Mar 31, 2023
1 parent 9cfdccd commit ae405ec
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ dev
* Python >=3.8 required
* `Type[...]` dispatches on class arguments
* `|` syntax for union types
* `overload` supports generics
* `overload` supports generics and forward references

1.9.1

Expand Down
22 changes: 16 additions & 6 deletions multimethod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,7 @@ def __call__(self, *args, **kwargs):
raise DispatchError(f"Function {func.__code__}") from ex

def evaluate(self):
"""Evaluate any pending forward references.
This can be called explicitly when using forward references,
otherwise cache misses will evaluate.
"""
"""Evaluate any pending forward references."""
while self.pending:
func = self.pending.pop()
self[get_types(func)] = func
Expand Down Expand Up @@ -388,21 +384,35 @@ def isa(*types: type) -> Callable:
class overload(dict):
"""Ordered functions which dispatch based on their annotated predicates."""

pending: set
__get__ = multimethod.__get__

def __new__(cls, func):
namespace = inspect.currentframe().f_back.f_locals
self = functools.update_wrapper(super().__new__(cls), func)
self.pending = set()
return namespace.get(func.__name__, self)

def __init__(self, func: Callable):
try:
sig = self.signature(func)
except (NameError, AttributeError):
self.pending.add(func)
else:
self[sig] = func

@classmethod
def signature(cls, func: Callable) -> inspect.Signature:
for name, value in get_type_hints(func).items():
if not callable(value) or isinstance(value, type) or hasattr(value, '__origin__'):
func.__annotations__[name] = isa(value)
self[inspect.signature(func)] = func
return inspect.signature(func)

def __call__(self, *args, **kwargs):
"""Dispatch to first matching function."""
while self.pending:
func = self.pending.pop()
self[self.signature(func)] = func
for sig in reversed(self):
try:
arguments = sig.bind(*args, **kwargs).arguments
Expand Down
10 changes: 10 additions & 0 deletions tests/test_overload.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,13 @@ def func(arg: Optional[str]):
assert func(None) is None
with pytest.raises(DispatchError):
func(0)


class cls:
@overload
def method(self: 'cls'):
return type(self)


def test_annotations():
assert cls().method() is cls

0 comments on commit ae405ec

Please sign in to comment.