Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-morel committed May 1, 2024
1 parent 25a17c7 commit 95261b3
Showing 1 changed file with 191 additions and 0 deletions.
191 changes: 191 additions & 0 deletions src/annotate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
from typing import Any, GenericAlias, Union
from functools import wraps, partial
from inspect import isclass

class AnnotationError(ValueError):
pass
class AnnotationErrors(AnnotationError):
pass
class InternalAnnotationError(Exception):
pass
class CastingError(TypeError):
pass
class OverloadError(TypeError):
pass
class AnnotationResolutionError(AnnotationError):
_raise = False

class Values(tuple):
"wrapper class in case of several value"
def __call__(self, val):
return val in self
def __str__(self):
return "Values(" + ", ".join(map(repr, self)) + ")"
__repr__ = __str__

def get_name(funcOrCls):
return funcOrCls.__module__+"."+funcOrCls.__qualname__
class TypeChecker:
def __init__(self, func):
if not callable(func):
raise TypeError(self.__class__.__init__.__qualname__)
self.func = func
def __call__(self, val):
try:
return self.func()
except Exception as e:
raise AnnotationError(
f"{type(e)} while using typechecker method: {get_name(self.func)}"+
f"\n{e!s}"
) from e
class Cast:
def __init__(self, type):
self.type = type
def __call__(self, val):
try:
return self.type(val)
except Exception as e:
raise CastingError(
f"Exception while casting: {val!r} to {self.type}"
) from e

def typeMatch(val, spec):
if spec == Any or spec == None:
return True
if isinstance(spec, Values):
return spec(val)
elif isinstance(spec, TypeChecker):
return spec(val)
elif isinstance(spec, GenericAlias):
if not isinstance(val, spec.__origin__):
return False
sub = Union[*spec.__args__]
for val in val:
if not typeMatch(val, sub):
return False
else:
return True
else:
return isinstance(val, spec)
def resolveAnnotations(anno, np, scope=None):
for k, v in anno.items():
if isinstance(v, str):
try:
anno[k] = eval(v, np, np)
except Exception as e:
raise AnnotationResolutionError(
f"Exception: {e!s} while resolving annotation {v!r} of {scope}"
) from e
def annotate(func, oload=False):
"decorator annotates wrapped function"
if isclass(func):
return annotateClass(func)
anno = func.__annotations__
@wraps(func)
def wrapper(*args, **kw):
names = tuple(anno.keys())
if any(isinstance(x, str) for x in anno.values()):
resolveAnnotations(anno, func.__globals__, get_name(func))
vals = {}
try:
if func.__defaults__:
for i, v in enumerate(reversed(func.__defaults__)):
vals[names[-1-i]] = v
for i, v in enumerate(args):
vals[names[i]] = v
vals.update(kw)
except IndexError as e:
raise AnnotationError(
f"Was function {get_name(func)} properly annotated?"
) from e

errors = []
for k, v in vals.items():
#(k, v)
if isinstance(anno[k], Cast):
vals[k] = anno[k](v)
#(" casting", k, "to", anno[k])
continue
if not typeMatch(v, anno[k]):
if oload:
raise InternalAnnotationError()
errors.append(
AnnotationError(
f"Value: {v!r} does not match annotation: {anno[k]!r} for argument {k!r} of function {get_name(func)}"
)
)
if len(errors) > 0:
raise AnnotationErrors(errors)

ret = func(**vals)
if "return" in anno:
if not typeMatch(ret, anno["return"]):
raise AnnotationError(
f"return value {ret!r} does not match annotation: {anno['return']} of function {get_name(func)}"
)
return ret
return wrapper
__overloads__ = {}

def overload(func, name=None):
if isinstance(func, str):
return partial(overload, name=func)
if name is None or not isinstance(name, str):
name = get_name(func)

if not name in __overloads__:
__overloads__[name] = []

__overloads__[name].append(annotate(func, True))
func.__overloads__ = __overloads__[name]
@wraps(func)
def wrapper(*args, **kw):
for f in __overloads__[name]:
try:
val = f(*args, **kw)
except InternalAnnotationError as e:
continue
else:
break
else:
raise OverloadError(
f"No overload of function: {get_name(func)} matches types of arguments"
)
return val
return wrapper

def annotateClass(cls):
if isinstance(cls, bool):
return partial(annotate, recur=cls)
recur = not hasattr(cls, "__annotate_norecur__")
setter = cls.__setattr__
anno = cls.__annotations__
if recur:
for x in dir(cls):
if hasattr(getattr(cls, x), "__annotations__"):
setattr(
cls,
x,
annotate(
getattr(
cls,
x
)
)
)
def new_setter(self, name, value):
if any(isinstance(x, str) for x in anno.values()):
resolveAnnotations(anno, globals(), get_name(cls))

if not name in anno:
anno[name] = type(value)
if not typeMatch(value, anno[name]):
raise AnnotationError(
f"value {value!r} does not match annotation of attribute: {name!r}:{anno[name]!r} of object of class {get_name(cls)}"
)
else:
return setter(self, name, value)
cls.__setattr__ = new_setter
return cls


0 comments on commit 95261b3

Please sign in to comment.