Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/multidispatch on kwargs #89

Closed

Conversation

svaningelgem
Copy link

Hi @coady,

I had an issue at work that keyword arguments were not properly matched, so I was diving into the matching engine for the full signature. As multimethod only matches on positional arguments, you hinted that I should use multidispatch, but also that one fell short of what I needed.

Hence I tried to look at other libraries, but seems not much in the market for this kind of routing.
I found one called overload, which was nearly what I needed it to do. So I took the tests and fixed the code so these tests could run through the new multidispatch code.

BTW, the only class I touched is multidispatch. (and I moved some multimethod tests to the test_method file as they belong there imho.

I want to raise this PR to get your idea about what I did and to get some pointers on how to improve on it.

One thing I'm not very certain about is what I also mentioned in the README:

from multimethod import multidispatch


@multidispatch
def func(a, b):
    return 1

@func.register
def _(a, b: float = 1.0, c: str = "A"):
    return 2


print(func(1, 2))  # 1
print(func(1, "A"))  # 1

The last bit happens because "A" is matched with b, which should be a float.
This is because I'm using the Signature.bind method, but that obviously doesn't take into account the type hintings...

If a full blown matching engine needs to be implemented that might lead too far and will slow things down considerably.
So I don't know if it's truly worth the effort to do this now.

At least the basics work: match ANY signature, with ANY kind of arguments.

Let's discuss :)

@coady
Copy link
Owner

coady commented Feb 24, 2023

I think overload - which is already builtin - is what you're looking for. It does a linear scan of the function signatures.

@overload
def func(a, b):
    return 1

@func.register
def _(a, b: float = 1.0, c: str = "A"):
    return 2

@svaningelgem
Copy link
Author

Yes and no...

I tried this by changing all my tests in the test_dispatch.py file to use overload:

def test_keywords():
    class cls: pass

    @overload
    def func(arg):
        return 0

    @overload
    def func(arg: int):
        return 1

    @overload
    def func(arg: int, extra: Union[int, float]):
        return 2

    @overload
    def func(arg: int, extra: str):
        return 3

    @overload
    def func(arg: int, *, extra: cls):
        return 4

    assert func("sth") == 0
    assert func(0) == func(arg=0) == 1
    assert func(0, 0.0) == func(arg=0, extra=0.0) == func(arg=0, extra=0.0) == 2
    assert func(0, 0) == func(0, extra=0) == func(arg=0, extra=0) == 2
    assert func(0, '') == func(0, extra='') == func(arg=0, extra='') == 3
    assert func(0, extra=cls()) == func(arg=0, extra=cls()) == 4

    with pytest.raises(DispatchError):
        func(0, cls())

Fails with

self = typing.Union, args = (0.0,), kwds = {}

    def __call__(self, *args, **kwds):
>       raise TypeError(f"Cannot instantiate {self!r}")
E       TypeError: Cannot instantiate typing.Union

It seems to largely work, except when varargs come into the picture.

Like:

def test_var_positional():
    """Check that we can overload instance methods with variable positional arguments."""

    class cls:
        @overload
        def func(self):
            return 1

        @overload()
        def func(self, *args: object):
            return 2

    assert cls().func() == 1
    assert cls().func(1) == 2

This always defers to the varargs one.

Or when defaults play:

def test_different_signatures():
    @overload
    def func(a: int):
        return f'int: {a}'

    @overload
    def func(a: int, b: float = 3.0):
        return f'int_float: {a} / {b}'

    assert func(1) == 'int: 1'

Fails with:

Expected :'int: 1'
Actual   :'int_float: 1 / 3.0'

It seems to largely do what I want it to do, but not entirely.
For example the last bit I pasted here: it seems to take the last one, or the one with the varargs... The one that can consume the most, whereas (logically) I would say it should use the first one that is capable of matching the signature.

@svaningelgem
Copy link
Author

Ok, changing this:

        for sig in reversed(self):
==>
        for sig in self:

Solved already a few of the test cases.

@svaningelgem
Copy link
Author

svaningelgem commented Feb 25, 2023

    def __init__(self, func: Callable):
        for name, value in get_type_hints(func).items():
            if getattr(value, '__origin__', None) is Union:
                func.__annotations__[name] = isa(value.__args__)
            elif not callable(value) or isinstance(value, type):
                func.__annotations__[name] = isa(value)

        self[inspect.signature(func)] = func

    def _check(self, param, value, sub: bool = False):
        if not sub:
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                return all(self._check(param, v, sub=True) for v in value)

            if param.kind == inspect.Parameter.VAR_KEYWORD:
                return all(self._check(param, v, sub=True) for v in value.values())

        return param.annotation is param.empty or param.annotation(value)

    def __call__(self, *args, **kwargs):
        """Dispatch to first matching function."""
        for sig in reversed(self):
            try:
                arguments = sig.bind(*args, **kwargs).arguments
            except TypeError:
                continue
            if all(
                self._check(param, arguments[name])
                for name, param in sig.parameters.items()
                if name in arguments
            ):
                return self[sig](*args, **kwargs)
        raise DispatchError("No matching functions found")

    @tp_overload
    def register(self, *args: type) -> Callable: ...

    @tp_overload
    def register(self, func: Callable) -> Callable: ...

    def register(self, *args) -> Callable:
        """Decorator for registering a function."""
        if len(args) == 1 and hasattr(args[0], '__annotations__'):
            func = args[0]
            self.__init__(func)
            return self if self.__name__ == func.__name__ else func  # type: ignore
        return lambda func: self.__setitem__(args, func) or func

This solves:

  1. registration with var-args
  2. union in type hinting
  3. taking the first matching method instead of the last

However, it does not solve this:

    @overload
    def roshambo(left, right):
        return 'tie'

    @roshambo.register(scissors, rock)
    @roshambo.register(rock, scissors)
    def roshambo(left, right):
        return 'rock smashes scissors'

Everything is a tie there.

@svaningelgem
Copy link
Author

Another thing that isn't supported by overload: pending types:

def test_unknown_types():
    class A: pass

    @overload
    def func(a: int):
        return 1

    @func.register
    def func(a: "A"):
        return 2

    assert func(1) == 1
    assert func(A()) == 2

==> E NameError: name 'A' is not defined

@ipcoder
Copy link

ipcoder commented Mar 27, 2023

Hi @svaningelgem
It seems you have added support for several important cases.
Any particular reason why Optional arguments are out of the scope?

@overload
def f(x: int, c: str = None)

coady added a commit that referenced this pull request Mar 31, 2023
Refs #89 #issuecomment-1445842092.
@coady
Copy link
Owner

coady commented Sep 24, 2023

As of v1.10, dispatch supports optionals.

@coady
Copy link
Owner

coady commented Nov 14, 2023

Closing since this seems stalled. Maybe it can be split into separate ideas.

@coady coady closed this Nov 14, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants