diff --git a/src/pyoload/__init__.py b/src/pyoload/__init__.py index 90acee4..f522072 100644 --- a/src/pyoload/__init__.py +++ b/src/pyoload/__init__.py @@ -32,6 +32,10 @@ from typing import Type NoneType = type(None) +try: + from types import UnionType +except ImportError: + UnionType = Union class AnnotationError(ValueError): @@ -467,6 +471,8 @@ def cast(val: Any, totype: Any) -> Any: :returns: An instance of the casting type """ + if totype == Any: + return val if isinstance(totype, GenericAlias): args = get_args(totype) if get_origin(totype) == dict: @@ -478,7 +484,7 @@ def cast(val: Any, totype: Any) -> Any: else: sub = args[0] return get_origin(totype)([Cast.cast(v, sub) for v in val]) - if get_origin(totype) is Union: + if get_origin(totype) is Union or get_origin(totype) is UnionType: errors = [] for subtype in get_args(totype): try: @@ -638,6 +644,10 @@ def resove_annotations(obj: Callable) -> None: :returns: None """ + if not hasattr(obj, '__annotations__'): + raise AnnotationResolutionError( + f"object {obj=!r} does not have `.__annotations__`", + ) if isclass(obj) or hasattr(obj, "__class__"): for k, v in obj.__annotations__.items(): if isinstance(v, str): @@ -666,7 +676,7 @@ def resove_annotations(obj: Callable) -> None: f"globals: {obj.__globals__}", ) from e else: - raise AnnotationError(f"unknown resolution method for {obj}") + raise AnnotationResolutionError(f"unknown resolution method for {obj}") def annotate( diff --git a/src/tests/logs.yaml b/src/tests/logs.yaml index 41326ed..9c71659 100644 --- a/src/tests/logs.yaml +++ b/src/tests/logs.yaml @@ -1,7 +1,7 @@ -type_match vs isinstance on int:True 1.6092411ms -type_match vs isinstance on int:False 1.4381614ms -type_match on dict[str, int]*50:True 271.9251959ms -type_match on dict[str, int]*50:False 17.4513811ms -Cast str->int: 7.204096ms -Cast complex->int | str: 0.43346336ms -Cast dict[int,list[str]*10]*10->dict[str,tuple[float]]: 8.15403988ms +type_match vs isinstance on int:True 1.3719336ms +type_match vs isinstance on int:False 1.0354459999999999ms +type_match on dict[str, int]*50:True 267.45216439999996ms +type_match on dict[str, int]*50:False 17.6830167ms +Cast str->int: 10.8046044ms +Cast complex->int | str: 0.51518408ms +Cast dict[int,list[str]*10]*10->dict[str,tuple[float]]: 12.7969466ms diff --git a/src/tests/test_annotate.py b/src/tests/test_annotate.py index 14bf4e5..45508e4 100644 --- a/src/tests/test_annotate.py +++ b/src/tests/test_annotate.py @@ -1,10 +1,12 @@ import pyoload +from pyoload import AnnotationResolutionError from pyoload import Cast from pyoload import annotable from pyoload import annotate from pyoload import is_annotable from pyoload import is_annoted +from pyoload import resove_annotations from pyoload import type_match from pyoload import unannotable from pyoload import unannotate @@ -72,6 +74,24 @@ def footy(a: 'Nothing here'): else: raise Exception() + try: + resove_annotations(None) + except AnnotationResolutionError: + pass + else: + raise Exception() + + @annotate + def fooar(a: 'str', b: 'int'): + pass + fooar('4', 3) + try: + fooar('4', '4') + except Exception: + pass + else: + raise Exception() + if __name__ == "__main__": test_annotate() diff --git a/src/tests/test_cast.py b/src/tests/test_cast.py index cf12bfc..3040d66 100644 --- a/src/tests/test_cast.py +++ b/src/tests/test_cast.py @@ -33,6 +33,9 @@ def test_cast(): q.foo = {1234: {"5", 16j}} assert type_match(q.foo, dict[str, tuple[Union[int, str]]]) assert type_match(q.bar, list[tuple[float]]) + assert Cast(dict[int])({'3': '7'}) == {3: '7'} + assert Cast(dict[Any, int])({'3': '7'}) == {'3': 7} + assert Cast(tuple[int | str])(('3', 2.5, 1j, '/6')) == (3, 2, '1j', '/6') try: @annotate diff --git a/src/tests/test_check.py b/src/tests/test_check.py index 4a01c13..6a348b8 100644 --- a/src/tests/test_check.py +++ b/src/tests/test_check.py @@ -131,12 +131,11 @@ def test_check(): if pyoload.get_name(check).split(".")[0] == "tests": continue pyoload.Checks(**{name: 3})(11) + pyoload.Checks(**{name: 11})(3) except pyoload.Check.CheckError: - if name in ('ge', 'gt'): - raise Exception() + pass else: - if name not in ('ge', 'gt'): - raise Exception(name, check) + raise Exception(name, check) pyoload.Checks(len=3)((1, 2, 3)) pyoload.Checks(len=2)((1, 2)) pyoload.Checks(len=slice(3, None))((1, 2, 3))