From b1d0603840f80cc4f22302575745ef72f47cc222 Mon Sep 17 00:00:00 2001 From: ken-morel Date: Sat, 8 Jun 2024 12:22:14 +0100 Subject: [PATCH] ... --- src/pyoload/__init__.py | 6 +++++- src/tests/logs.yaml | 14 +++++++------- src/tests/test_annotate.py | 22 ++++++++++++++++++++++ src/tests/test_check.py | 2 ++ 4 files changed, 36 insertions(+), 8 deletions(-) diff --git a/src/pyoload/__init__.py b/src/pyoload/__init__.py index 9eb60ae..e01d352 100644 --- a/src/pyoload/__init__.py +++ b/src/pyoload/__init__.py @@ -675,6 +675,8 @@ def annotate( :returns: the wrapper function """ + if not hasattr(func, '__annotations__'): + return func if isclass(func): return annotateClass(func) if len(func.__annotations__) == 0: @@ -717,7 +719,7 @@ def wrapper(*pargs, **kw): if len(errors) > 0: raise AnnotationErrors(errors) - ret = func(*pargs, **kw) + ret = func(**args.arguments) if sign.return_annotation is not _empty: ann = sign.return_annotation @@ -748,10 +750,12 @@ def unannotate(func: Callable) -> Callable: def unannotable(func: Callable) -> Callable: func = unannotate(func) func.__pyod_annotable__ = False + return func def annotable(func: Callable) -> Callable: func.__pyod_annotable__ = True + return func def is_annotable(func): diff --git a/src/tests/logs.yaml b/src/tests/logs.yaml index 04c3fcc..935dc6a 100644 --- a/src/tests/logs.yaml +++ b/src/tests/logs.yaml @@ -1,7 +1,7 @@ -type_match vs isinstance on int:True 1.0719632ms -type_match vs isinstance on int:False 1.2783177ms -type_match on dict[str, int]*50:True 278.8694041ms -type_match on dict[str, int]*50:False 17.2757097ms -Cast str->int: 7.049113999999999ms -Cast complex->int | str: 0.29363386ms -Cast dict[int,list[str]*10]*10->dict[str,tuple[float]]: 8.06158008ms +type_match vs isinstance on int:True 0.9860181ms +type_match vs isinstance on int:False 1.2102531ms +type_match on dict[str, int]*50:True 286.9993945ms +type_match on dict[str, int]*50:False 17.3531196ms +Cast str->int: 7.2130632ms +Cast complex->int | str: 0.27411661ms +Cast dict[int,list[str]*10]*10->dict[str,tuple[float]]: 8.13545716ms diff --git a/src/tests/test_annotate.py b/src/tests/test_annotate.py index 75882cb..f6070b2 100644 --- a/src/tests/test_annotate.py +++ b/src/tests/test_annotate.py @@ -1,6 +1,12 @@ import pyoload +from pyoload import Cast +from pyoload import annotable from pyoload import annotate +from pyoload import is_annotable +from pyoload import is_annoted +from pyoload import unannotable +from pyoload import unannotate assert pyoload.__version__ == "2.0.0" @@ -12,8 +18,24 @@ def foo(a, b=3, c: str = "R") -> int: return 3 +def foo1(a: Cast(str)): + return a + + def test_annotate(): foo(2) + assert annotate(foo1)(3) == '3' + assert unannotate(annotate(foo1))(3) == 3 + assert annotate(unannotable(foo1))(3) == 3 + assert annotate(unannotable(foo1))(3) == 3 + assert annotate(unannotable(foo1), force=True)(3) == '3' + assert annotate(annotable(unannotable(foo1)))(3) == '3' + assert is_annotable(foo1) + assert is_annotable(annotable(foo1)) + assert is_annotable(annotable(unannotable(foo1))) + assert not is_annotable(unannotable(foo1)) + assert is_annoted(annotate(foo1, force=True)) + assert not is_annoted(foo1) if __name__ == "__main__": diff --git a/src/tests/test_check.py b/src/tests/test_check.py index d481e41..8544910 100644 --- a/src/tests/test_check.py +++ b/src/tests/test_check.py @@ -67,6 +67,8 @@ def test_check(): if pyoload.get_name(check).split('.')[0] == 'tests': continue pyoload.Checks(**{name: NotImplemented})(24) + pyoload.Checks(**{name: int})(11) + pyoload.Checks(**{name: 3})(11) except pyoload.Check.CheckError: pass else: