Skip to content

Commit

Permalink
Update __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-morel committed May 10, 2024
1 parent 4579395 commit 8aca8b3
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions src/pyoload/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, GenericAlias, Union
from types import UnionType
from functools import wraps, partial
from inspect import isclass

Expand Down Expand Up @@ -62,32 +63,36 @@ def __call__(self, val):
class Cast:
@staticmethod
def cast(val, type):
if issubclass(type, Union):
if isinstance(type, UnionType):
type = type.__args__
if isinstance(type, tuple):
e = None
for x in type:
try:
return Cast.cast(val, x)
except Exception:
return Cast.cast(val, type)
except Exception as e:
pass
else:
raise CastingError()
return type(val) if not isinstance(val, type) else val
raise e
else:
print(f"casting {val!r} to {type!r}")
return type(val) if not isinstance(val, type) else val

def __init__(self, type):
self.type = type

def __call__(self, val):
try:
return Cast.cast(self.type, val)
return Cast.cast(val, self.type)
except Exception as e:
raise e from e
raise CastingError(
f'Exception while casting: {val!r} to {self.type}',
f'Exception({e}) while casting: {val!r} to {self.type}',
) from e


def typeMatch(val, spec):
if spec == Any or spec is None:
if spec == Any or spec is None or val is None:
return True
if isinstance(val, tuple):
return isinstance(val, tuple)
Expand Down Expand Up @@ -124,6 +129,8 @@ def annotate(func, oload=False):
if isclass(func):
return annotateClass(func)
anno = func.__annotations__
if len(anno) == 0:
return func

@wraps(func)
def wrapper(*args, **kw):
Expand Down Expand Up @@ -225,14 +232,15 @@ def new_setter(self, name, value):
resolveAnnotations(anno, globals(), get_name(cls))

if name not in anno:
anno[name] = type(value)
if not typeMatch(value, anno[name]):
if value is not None:
anno[name] = type(value)
elif not typeMatch(value, anno[name]):
raise AnnotationError(
f'value {value!r} does not match annotation' +
'of attribute: {name!r}:{anno[name]!r} of object of class {get_name(cls)}',
f'of attribute: {name!r}:{anno[name]!r} of object of class {get_name(cls)}',
)
else:
return setter(self, name, value)

return setter(self, name, value)
cls.__setattr__ = new_setter
return cls

Expand Down

0 comments on commit 8aca8b3

Please sign in to comment.