Skip to content

Commit

Permalink
Update __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-morel committed Jun 5, 2024
1 parent ade4966 commit 6e87d45
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions src/pyoload/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,20 +344,21 @@ def cast(val: Any, totype: Any) -> Any:
:returns: An instance of the casting type
"""
if isinstance(totype, GenericAlias):
if totype.__origin__ == dict:
if len(totype.__args__) == 2:
kt, vt = totype.__args__
elif len(totype.__args__) == 1:
kt, vt = Any, totype.__args__[1]
if get_origin(totype) == dict:
args = get_args(totype)
if len(args) == 2:
kt, vt = args
elif len(args) == 1:
kt, vt = args[0], Any
return {
Cast.cast(k, kt): Cast.cast(v, vt) for k, v in val.items()
}
else:
sub = totype.__args__[0]
return totype.__origin__([Cast.cast(v, sub) for v in val])
sub = args[0]
return get_origin(totype)([Cast.cast(v, sub) for v in val])
if get_origin(totype) is Union:
errors = []
for subtype in totype.__args__:
for subtype in get_args(totype):
try:
return Cast.cast(val, subtype)
except Exception as e:
Expand Down Expand Up @@ -513,7 +514,7 @@ def typeMatch(val: Any, spec: Any) -> bool:
return True


def resolveAnnotations(obj: Type | Callable) -> None:
def resolveAnnotations(obj: Callable) -> None:
"""
Evaluates all the stringized annotations of the argument
Expand Down Expand Up @@ -708,7 +709,7 @@ def wrapper(*args, **kw):
return wrapper


def annotateClass(cls: Any):
def annotateClass(cls: Any, recur: bool = True):
"""
Annotates a class object, wrapping and replacing over it's __setattr__
and typechecking over each attribute assignment.
Expand All @@ -717,11 +718,12 @@ def annotateClass(cls: Any):
it recursively annotates the classes methods except `__pyod_norecur__`
attribute is defines
"""

if isinstance(cls, bool):
return partial(annotateClass, recur=cls)
if not hasattr(cls, "__annotations__"):
cls.__annotations__ = {}
if isinstance(cls, bool):
return partial(annotate, recur=cls)
recur = not hasattr(cls, "__pyod_norecur__")
recur = not hasattr(cls, "__pyod_norecur__") and recur
setter = cls.__setattr__
if recur:
for x in dir(cls):
Expand All @@ -739,7 +741,7 @@ def annotateClass(cls: Any):

@wraps(cls.__setattr__)
def new_setter(self: Any, name: str, value: Any) -> Any:
if any(isinstance(x, str) for x in self.__annotations__.values()):
if str in map(type, self.__annotations__.values()):
resolveAnnotations(self)

if name not in self.__annotations__:
Expand Down

0 comments on commit 6e87d45

Please sign in to comment.