Skip to content

Commit

Permalink
Create __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-morel committed May 2, 2024
1 parent 1dffc6b commit 6266fa6
Showing 1 changed file with 224 additions and 0 deletions.
224 changes: 224 additions & 0 deletions src/pyoload/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
from typing import Any, GenericAlias, Union
from functools import wraps, partial
from inspect import isclass


class AnnotationError(ValueError):
pass


class AnnotationErrors(AnnotationError):
pass


class InternalAnnotationError(Exception):
pass


class CastingError(TypeError):
pass


class OverloadError(TypeError):
pass


class AnnotationResolutionError(AnnotationError):
_raise = False


class Values(tuple):
"""wrapper class in case of several value"""

def __call__(self, val):
return val in self

def __str__(self):
return 'Values(' + ', '.join(map(repr, self)) + ')'

__repr__ = __str__


def get_name(funcOrCls):
return funcOrCls.__module__ + '.' + funcOrCls.__qualname__


class TypeChecker:
def __init__(self, func):
if not callable(func):
raise TypeError(self.__class__.__init__.__qualname__)
self.func = func

def __call__(self, val):
try:
return self.func()
except Exception as e:
raise AnnotationError(
f'{type(e)} while using typechecker method: {get_name(self.func)}' +
f'\n{e!s}',
) from e


class Cast:
def __init__(self, type):
self.type = type

def __call__(self, val):
try:
return self.type(val)
except Exception as e:
raise CastingError(
f'Exception while casting: {val!r} to {self.type}',
) from e


def typeMatch(val, spec):
if spec == Any or spec is None:
return True
if isinstance(spec, Values):
return spec(val)
elif isinstance(spec, TypeChecker):
return spec(val)
elif isinstance(spec, GenericAlias):
if not isinstance(val, spec.__origin__):
return False
sub = Union[*spec.__args__]
for val in val:
if not typeMatch(val, sub):
return False
else:
return True
else:
return isinstance(val, spec)


def resolveAnnotations(anno, np, scope=None):
for k, v in anno.items():
if isinstance(v, str):
try:
anno[k] = eval(v, np, np)
except Exception as e:
raise AnnotationResolutionError(
f'Exception: {e!s} while resolving annotation {v!r} of {scope}',
) from e


def annotate(func, oload=False):
"""decorator annotates wrapped function"""
if isclass(func):
return annotateClass(func)
anno = func.__annotations__

@wraps(func)
def wrapper(*args, **kw):
names = tuple(anno.keys())
if any(isinstance(x, str) for x in anno.values()):
resolveAnnotations(anno, func.__globals__, get_name(func))
vals = {}
try:
if func.__defaults__:
for i, v in enumerate(reversed(func.__defaults__)):
vals[names[-1 - i]] = v
for i, v in enumerate(args):
vals[names[i]] = v
vals.update(kw)
except IndexError as e:
raise AnnotationError(
f'Was function {get_name(func)} properly annotated?',
) from e

errors = []
for k, v in vals.items():
if isinstance(anno[k], Cast):
vals[k] = anno[k](v)
continue
if not typeMatch(v, anno[k]):
if oload:
raise InternalAnnotationError()
errors.append(
AnnotationError(
f'Value: {v!r} does not match annotation: {anno[k]!r}' +
f' for argument {k!r} of function {get_name(func)}',
),
)
if len(errors) > 0:
raise AnnotationErrors(errors)

ret = func(**vals)
if 'return' in anno:
if not typeMatch(ret, anno['return']):
raise AnnotationError(
f"return value {ret!r} does not match annotation: {anno['return']} of function {get_name(func)}",
)
return ret
return wrapper


__overloads__ = {}


def overload(func, name=None):
if isinstance(func, str):
return partial(overload, name=func)
if name is None or not isinstance(name, str):
name = get_name(func)
if name not in __overloads__:
__overloads__[name] = []
__overloads__[name].append(annotate(func, True))
func.__overloads__ = __overloads__[name]

@wraps(func)
def wrapper(*args, **kw):
for f in __overloads__[name]:
try:
val = f(*args, **kw)
except InternalAnnotationError:
continue
else:
break
else:
raise OverloadError(
f'No overload of function: {get_name(func)} matches types of arguments',
)
return val
return wrapper


def annotateClass(cls):
if isinstance(cls, bool):
return partial(annotate, recur=cls)
recur = not hasattr(cls, '__annotate_norecur__')
setter = cls.__setattr__
anno = cls.__annotations__
if recur:
for x in dir(cls):
if hasattr(getattr(cls, x), '__annotations__'):
setattr(
cls,
x,
annotate(
getattr(
cls,
x,
),
),
)

def new_setter(self, name, value):
if any(isinstance(x, str) for x in anno.values()):
resolveAnnotations(anno, globals(), get_name(cls))

if name not in anno:
anno[name] = type(value)
if not typeMatch(value, anno[name]):
raise AnnotationError(
f'value {value!r} does not match annotation' +
'of attribute: {name!r}:{anno[name]!r} of object of class {get_name(cls)}',
)
else:
return setter(self, name, value)
cls.__setattr__ = new_setter
return cls


__version__ = '1.0.0'

0 comments on commit 6266fa6

Please sign in to comment.