Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-morel committed Jun 5, 2024
1 parent 513de22 commit 6877d37
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 72 deletions.
58 changes: 30 additions & 28 deletions src/pyoload/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import partial
from functools import wraps
from inspect import _empty
from inspect import getmodule
from inspect import isclass
from inspect import signature
from types import NoneType
Expand All @@ -15,8 +16,6 @@
from typing import GenericAlias
from typing import Type

import sys


class AnnotationError(ValueError):
"""
Expand Down Expand Up @@ -346,7 +345,9 @@ def cast(val: Any, totype: Any) -> Any:
kt, vt = totype.__args__
elif len(totype.__args__) == 1:
kt, vt = Any, totype.__args__[1]
return {Cast.cast(k, kt): Cast.cast(v, vt) for k, v in val.items()}
return {
Cast.cast(k, kt): Cast.cast(v, vt) for k, v in val.items()
}
else:
sub = totype.__args__[0]
return totype.__origin__([Cast.cast(v, sub) for v in val])
Expand Down Expand Up @@ -504,27 +505,6 @@ def typeMatch(val: Any, spec: Any) -> bool:
return isinstance(val, spec)


def get_module(obj: Any):
"""
gets the module to which an object, function or class belongs
e.g
>>> class foo:
... def bar(self):
... pass
...
>>> get_name(foo)
'__main__.foo'
>>> get_name(foo.bar)
'__main__.foo.bar'
:param obj: the object
:returns: the module
"""
return sys.modules[obj.__module__]


def resolveAnnotations(obj: Type | Callable) -> None:
"""
Evaluates all the stringized annotations of the argument
Expand All @@ -539,7 +519,7 @@ def resolveAnnotations(obj: Type | Callable) -> None:
try:
obj.__annotations__[k] = eval(
v,
dict(vars(get_module(obj))),
dict(vars(getmodule(obj))),
dict(vars(obj)),
)
except Exception as e:
Expand Down Expand Up @@ -585,6 +565,8 @@ def annotate(
return annotateClass(func)
if len(func.__annotations__) == 0:
return func
if not is_annotable(func) and not force:
return func

@wraps(func)
def wrapper(*pargs, **kw):
Expand Down Expand Up @@ -648,6 +630,26 @@ def wrapper(*pargs, **kw):
return wrapper


def unannotate(func: Callable) -> Callable:
if hasattr(func, '__pyod_annotate__'):
return func.__pyod_annotate__
else:
return func


def unannotable(func: Callable) -> Callable:
func = unannotate(func)
func.__pyod_annotable__ = False


def annotable(func: Callable) -> Callable:
func.__pyod_annotable__ = True


def is_annotable(func):
return not hasattr(func, '__pyod_annotable__') or func.__pyod_annotable__


__overloads__: dict[str, list[Callable]] = {}


Expand Down Expand Up @@ -677,7 +679,7 @@ def overload(func: Callable, name: str | None = None) -> Callable:
name = get_name(func)
if name not in __overloads__:
__overloads__[name] = []
__overloads__[name].append(annotate(func, True))
__overloads__[name].append(annotate(func, oload=True))

@wraps(func)
def wrapper(*args, **kw):
Expand Down Expand Up @@ -752,5 +754,5 @@ def new_setter(self: Any, name: str, value: Any) -> Any:
return cls


__version__ = '2.0.0'
__author__ = 'ken-morel'
__version__ = "2.0.0"
__author__ = "ken-morel"
12 changes: 6 additions & 6 deletions src/tests/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@
from pyoload import typeMatch
from pyoload import AnnotationError

assert pyoload.__version__ == '2.0.0'
assert pyoload.__version__ == "2.0.0"


@annotate
class foo:
foo = CastedAttr(dict[str, tuple[int | str]])
bar: Cast(list[tuple[float]])
a: 'str'
a: "str"

def __init__(self: 'Any', bar: 'list') -> Any:
def __init__(self: "Any", bar: "list") -> Any:
self.bar = bar
self.a = "ama"
try:
Expand All @@ -28,11 +28,11 @@ def __init__(self: 'Any', bar: 'list') -> Any:


def test_cast():
q = foo([(1, '67')])
q.foo = {1234: {'5', 16j}}
q = foo([(1, "67")])
q.foo = {1234: {"5", 16j}}
assert typeMatch(q.foo, dict[str, tuple[int | str]])
assert typeMatch(q.bar, list[tuple[float]])


if __name__ == '__main__':
if __name__ == "__main__":
test_cast()
22 changes: 12 additions & 10 deletions src/tests/test_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pyoload import Check
from pyoload import annotate

assert pyoload.__version__ == '2.0.0'
assert pyoload.__version__ == "2.0.0"


@annotate
Expand All @@ -19,13 +19,13 @@ def __init__(self: Any, bar: Checks(func=bool)) -> Any:
pass


@Check.register('test1 test2')
@Check.register("test1 test2")
def __(param, val):
print(param, val)


class IsInt(Check):
name = 'isint'
name = "isint"

def __call__(self, a, b):
return a == isinstance(b, int)
Expand All @@ -51,15 +51,17 @@ def test_check():
Checks(test1=3)(3)
Checks(test2=4)(4)
Checks(ge=2, gt=1, lt=2.1, le=2, eq=2)(2)
Checks(ge=-2.5, gt=-3, lt=-2, le=2, eq=-2.5)(-2.5)
Checks(len=(2, 5))('abcd')
Checks(type=dict[str | int, tuple[int]])({
'#': (12,),
20: (21, 45),
})
print(Checks(ge=-2.5, gt=-3, lt=-2, le=2, eq=-2.5)(-2.5))
Checks(len=(2, 5))("abcd")
Checks(type=dict[str | int, tuple[int]])(
{
"#": (12,),
20: (21, 45),
}
)
Checks(isinstance=float)(1.5)
Checks(isint=True)(5)


if __name__ == '__main__':
if __name__ == "__main__":
test_check()
39 changes: 15 additions & 24 deletions src/tests/test_overload.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,41 @@
import pyoload

from pyoload import Any
from pyoload import Checks
from pyoload import get_name
from pyoload import overload

assert pyoload.__version__ == '2.0.0'
assert pyoload.__version__ == "2.0.0"


@overload
def div(a: str, b: str):
return str(float(a) / float(b))


@overload(get_name(div))
def div_(a: str, b: Checks(eq=0)):
return 'Infinity'
def div(a: str, b: Checks(eq=0)):
return "Infinity"


@overload(get_name(div))
def div__(a: Any, b: Checks(eq=0)):
return NotImplemented
def div_(a: str, b: str):
print(f"{a=}, {b=} -> str")
return str(float(a) / float(b))


@div_.overload
def div___(a: str, b: int):
print(f"{a=}, {b=} -> int")
return int(float(a) / b)


@div.overload
def div____(a: float, b: float):
print(f"{a=}, {b=} -> float")
return float(float(a) / b)


def test_overload():
print(div.__pyod_overloads__, div.__pyod_overloads_name__)
print(div_.__pyod_overloads__, div_.__pyod_overloads_name__)
print(div__.__pyod_overloads__, div__.__pyod_overloads_name__)
print(div___.__pyod_overloads__, div___.__pyod_overloads_name__)
print(div____.__pyod_overloads__, div____.__pyod_overloads_name__)
assert div('4', '2') == '2.0'
assert div('3', 0) == 'Infinity'
assert div(..., 0) == NotImplemented
assert div('4', 2) == 2
assert div(3.0, 1.0) == 3.0


if __name__ == '__main__':
#assert div("4", "2") == "2.0"
#assert div(..., 0) == "Infinity"
assert div("4", 2) == 2
#assert div(3.0, 1.0) == 3.0


if __name__ == "__main__":
test_overload()
8 changes: 4 additions & 4 deletions src/tests/test_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pyoload


assert pyoload.__version__ == '2.0.0'
assert pyoload.__version__ == "2.0.0"


@annotate
Expand All @@ -13,12 +13,12 @@ def odd(a: Values(range(10))) -> bool:


def test_values():
assert odd(3), '3 reported not odd'
assert not odd(2), '2 reported odd'
assert odd(3), "3 reported not odd"
assert not odd(2), "2 reported odd"

try:
odd(10)
except AnnotationError:
pass
else:
raise AssertionError('Values did not crash')
raise AssertionError("Values did not crash")

0 comments on commit 6877d37

Please sign in to comment.