Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-morel committed Jun 9, 2024
1 parent 17e8b01 commit 7d81ccb
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 13 deletions.
14 changes: 12 additions & 2 deletions src/pyoload/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
from typing import Type

NoneType = type(None)
try:
from types import UnionType
except ImportError:
UnionType = Union


class AnnotationError(ValueError):
Expand Down Expand Up @@ -467,6 +471,8 @@ def cast(val: Any, totype: Any) -> Any:
:returns: An instance of the casting type
"""
if totype == Any:
return val
if isinstance(totype, GenericAlias):
args = get_args(totype)
if get_origin(totype) == dict:
Expand All @@ -478,7 +484,7 @@ def cast(val: Any, totype: Any) -> Any:
else:
sub = args[0]
return get_origin(totype)([Cast.cast(v, sub) for v in val])
if get_origin(totype) is Union:
if get_origin(totype) is Union or get_origin(totype) is UnionType:
errors = []
for subtype in get_args(totype):
try:
Expand Down Expand Up @@ -638,6 +644,10 @@ def resove_annotations(obj: Callable) -> None:
:returns: None
"""
if not hasattr(obj, '__annotations__'):
raise AnnotationResolutionError(
f"object {obj=!r} does not have `.__annotations__`",
)
if isclass(obj) or hasattr(obj, "__class__"):
for k, v in obj.__annotations__.items():
if isinstance(v, str):
Expand Down Expand Up @@ -666,7 +676,7 @@ def resove_annotations(obj: Callable) -> None:
f"globals: {obj.__globals__}",
) from e
else:
raise AnnotationError(f"unknown resolution method for {obj}")
raise AnnotationResolutionError(f"unknown resolution method for {obj}")


def annotate(
Expand Down
14 changes: 7 additions & 7 deletions src/tests/logs.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
type_match vs isinstance on int:True 1.6092411ms
type_match vs isinstance on int:False 1.4381614ms
type_match on dict[str, int]*50:True 271.9251959ms
type_match on dict[str, int]*50:False 17.4513811ms
Cast str->int: 7.204096ms
Cast complex->int | str: 0.43346336ms
Cast dict[int,list[str]*10]*10->dict[str,tuple[float]]: 8.15403988ms
type_match vs isinstance on int:True 1.3719336ms
type_match vs isinstance on int:False 1.0354459999999999ms
type_match on dict[str, int]*50:True 267.45216439999996ms
type_match on dict[str, int]*50:False 17.6830167ms
Cast str->int: 10.8046044ms
Cast complex->int | str: 0.51518408ms
Cast dict[int,list[str]*10]*10->dict[str,tuple[float]]: 12.7969466ms
20 changes: 20 additions & 0 deletions src/tests/test_annotate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import pyoload

from pyoload import AnnotationResolutionError
from pyoload import Cast
from pyoload import annotable
from pyoload import annotate
from pyoload import is_annotable
from pyoload import is_annoted
from pyoload import resove_annotations
from pyoload import type_match
from pyoload import unannotable
from pyoload import unannotate
Expand Down Expand Up @@ -72,6 +74,24 @@ def footy(a: 'Nothing here'):
else:
raise Exception()

try:
resove_annotations(None)
except AnnotationResolutionError:
pass
else:
raise Exception()

@annotate
def fooar(a: 'str', b: 'int'):
pass
fooar('4', 3)
try:
fooar('4', '4')
except Exception:
pass
else:
raise Exception()


if __name__ == "__main__":
test_annotate()
3 changes: 3 additions & 0 deletions src/tests/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def test_cast():
q.foo = {1234: {"5", 16j}}
assert type_match(q.foo, dict[str, tuple[Union[int, str]]])
assert type_match(q.bar, list[tuple[float]])
assert Cast(dict[int])({'3': '7'}) == {3: '7'}
assert Cast(dict[Any, int])({'3': '7'}) == {'3': 7}
assert Cast(tuple[int | str])(('3', 2.5, 1j, '/6')) == (3, 2, '1j', '/6')

try:
@annotate
Expand Down
7 changes: 3 additions & 4 deletions src/tests/test_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,11 @@ def test_check():
if pyoload.get_name(check).split(".")[0] == "tests":
continue
pyoload.Checks(**{name: 3})(11)
pyoload.Checks(**{name: 11})(3)
except pyoload.Check.CheckError:
if name in ('ge', 'gt'):
raise Exception()
pass
else:
if name not in ('ge', 'gt'):
raise Exception(name, check)
raise Exception(name, check)
pyoload.Checks(len=3)((1, 2, 3))
pyoload.Checks(len=2)((1, 2))
pyoload.Checks(len=slice(3, None))((1, 2, 3))
Expand Down

0 comments on commit 7d81ccb

Please sign in to comment.