Skip to content

Commit

Permalink
adding support for class annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-morel committed May 22, 2024
1 parent ae1b57b commit 5b2ad7c
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 39 deletions.
13 changes: 11 additions & 2 deletions pyoload.sublime-project
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
"folders":
[
{
"path": "."
"path": ".",
}
]
],
"build_systems":
[
{
"file_regex": "^[ ]*File \"(...*?)\", line ([0-9]*)",
"name": "Anaconda Python Builder",
"selector": "source.python",
"shell_cmd": "\"python\" -u \"$file\""
}
],
}
21 changes: 20 additions & 1 deletion src/del.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,21 @@
from pyoload import *
a = d
import pyoload

assert pyoload.__version__ == '1.0.2'

@annotate
class foo:
fa:'str'
def __init__(self: 'foo', bar: Cast(dict[int, list[float]])):
self.foo = bar
print(bar)
self.fa = 3


b = foo({'1':['1.0', 3]})



"""
mentor-no = 694190032
"""
115 changes: 79 additions & 36 deletions src/pyoload/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from functools import wraps, partial
from inspect import isclass

import sys


class AnnotationError(ValueError):
pass
Expand Down Expand Up @@ -71,21 +73,33 @@ def __call__(self, val):

class Cast:
@staticmethod
def cast(val, type):
if isinstance(type, UnionType):
type = type.__args__
if isinstance(type, tuple):
e = None
for x in type:
def cast(val, totype):
if isinstance(totype, GenericAlias):
if totype.__origin__ == dict:
if len(totype.__args__) == 2:
kt, vt = totype.__args__
elif len(totype.__args__) == 1:
kt, vt = Any, totype.__args__[1]
print(f"{totype=} {kt=} {vt=}")
return {
Cast.cast(k, kt): Cast.cast(v, vt) for k, v in val.items()
}
else:
sub = totype.__args__[0]
return totype.__origin__([
Cast.cast(v, sub) for v in val
])
if isinstance(totype, UnionType):
errors = []
for subtype in totype.__args__:
try:
return Cast.cast(val, type)
return Cast.cast(val, subtype)
except Exception as e:
pass
errors.append(e)
else:
raise e
else:
print(f"casting {val!r} to {type!r}")
return type(val) if not isinstance(val, type) else val
return totype(val) if not isinstance(val, totype) else val

def __init__(self, type):
self.type = type
Expand All @@ -94,7 +108,6 @@ def __call__(self, val):
try:
return Cast.cast(val, self.type)
except Exception as e:
raise e from e
raise CastingError(
f'Exception({e}) while casting: {val!r} to {self.type}',
) from e
Expand All @@ -103,34 +116,63 @@ def __call__(self, val):
def typeMatch(val, spec):
if spec == Any or spec is None or val is None:
return True
if isinstance(val, tuple):
return isinstance(val, tuple)
if isinstance(spec, Values):
return spec(val)
elif isinstance(spec, Validator):
return spec(val)
elif isinstance(spec, GenericAlias):
if not isinstance(val, spec.__origin__):
return False
sub = spec.__args__
for val in val:
if not typeMatch(val, sub):
return False

if spec.__origin__ == dict:
if len(spec.__args__) == 2:
kt, vt = spec.__args__
elif len(spec.__args__) == 1:
kt, vt = Any, spec.__args__[1]
else:
return True

for k, v in val.items():
if not typeMatch(k, kt) or not typeMatch(v, vt):
return False
else:
return True
else:
return True
sub = spec.__args__[0]
for val in val:
if not typeMatch(val, sub):
return False
else:
return True
else:
return isinstance(val, spec)


def resolveAnnotations(anno, np, scope=None):
for k, v in anno.items():
if isinstance(v, str):
try:
anno[k] = eval(v, np, np)
except Exception as e:
raise AnnotationResolutionError(
f'Exception: {e!s} while resolving annotation {v!r} of {scope}',
) from e
def get_module(obj):
return sys.modules[obj.__module__]

def resolveAnnotations(obj):
if isclass(obj):
for k, v in obj.__annotations__.items():
if isinstance(v, str):
try:
obj.__annotations__[k] = eval(v, get_module(obj).__globals__, dict(vars(obj)))
except Exception as e:
raise AnnotationResolutionError(
f'Exception: {k!s} while resolving'
f' annotation {v!r} of object {obj!r}',
) from e
elif callable(obj):
for k, v in obj.__annotations__.items():
if isinstance(v, str):
try:
obj.__annotations__[k] = eval(v, obj.__globals__)
except Exception as e:
raise AnnotationResolutionError(
f'Exception: {k!s} while resolving'
f' annotation {v!r} of function {obj!r}',
f'globals: {obj.__globals__}'
) from e


def annotate(func, oload=False):
Expand All @@ -145,7 +187,7 @@ def annotate(func, oload=False):
def wrapper(*args, **kw):
names = tuple(anno.keys())
if any(isinstance(x, str) for x in anno.values()):
resolveAnnotations(anno, func.__globals__, get_name(func))
resolveAnnotations(func)
vals = {}
try:
if func.__defaults__:
Expand Down Expand Up @@ -217,11 +259,12 @@ def wrapper(*args, **kw):


def annotateClass(cls):
if not hasattr(cls, '__annotations__'):
cls.__annotations__ = {}
if isinstance(cls, bool):
return partial(annotate, recur=cls)
recur = not hasattr(cls, '__annotate_norecur__')
setter = cls.__setattr__
anno = cls.__annotations__
if recur:
for x in dir(cls):
if hasattr(getattr(cls, x), '__annotations__'):
Expand All @@ -237,21 +280,21 @@ def annotateClass(cls):
)

def new_setter(self, name, value):
if any(isinstance(x, str) for x in anno.values()):
resolveAnnotations(anno, globals(), get_name(cls))
if any(isinstance(x, str) for x in self.__annotations__.values()):
resolveAnnotations(self)

if name not in anno:
if name not in self.__annotations__:
if value is not None:
anno[name] = type(value)
elif not typeMatch(value, anno[name]):
self.__annotations__[name] = type(value)
elif not typeMatch(value, self.__annotations__[name]):
raise AnnotationError(
f'value {value!r} does not match annotation' +
f'of attribute: {name!r}:{anno[name]!r} of object of class {get_name(cls)}',
f'of attribute: {name!r}:{self.__annotations__[name]!r} of object of class {get_name(cls)}',
)

return setter(self, name, value)
cls.__setattr__ = new_setter
return cls


__version__ = '1.0.1'
__version__ = '1.0.2'

0 comments on commit 5b2ad7c

Please sign in to comment.