152 lines
3.8 KiB
Python
152 lines
3.8 KiB
Python
import inspect
|
|
from functools import wraps
|
|
|
|
|
|
class Validator:
    """Base descriptor that validates a value before storing it.

    Subclasses override the ``check`` classmethod, which returns the value
    to store or raises. Used as a class attribute, assigning to the
    attribute on an instance runs ``check`` and writes the result into the
    instance ``__dict__`` under the attribute's name. No ``__get__`` is
    defined, so reads find that stored ``__dict__`` entry directly.
    """

    def __init__(self, name=None):
        # Optional explicit name; __set_name__ overwrites it when the
        # descriptor is assigned inside a class body.
        self.name = name

    @classmethod
    def check(cls, value):
        """Accept *value* unchanged; terminates the cooperative check chain."""
        return value

    def __set_name__(self, cls, name):
        """Remember the attribute name this descriptor is bound to."""
        self.name = name

    def __set__(self, instance, value):
        """Validate *value* and store the checked result on *instance*."""
        checked = self.check(value)
        instance.__dict__[self.name] = checked
|
|
|
|
|
|
class Typed(Validator):
    """Validator requiring values to be instances of ``expected_type``.

    Subclasses set ``expected_type``; the default ``object`` admits anything.
    """

    expected_type = object

    @classmethod
    def check(cls, value):
        """Raise TypeError unless *value* is an ``expected_type`` instance,
        otherwise hand off to the next ``check`` in the MRO."""
        wanted = cls.expected_type
        if isinstance(value, wanted):
            return super().check(value)
        raise TypeError(f"expected {wanted}")
|
|
|
|
|
|
class Integer(Typed):
    """Typed validator for ``int`` (``isinstance`` also admits ``bool``)."""

    expected_type = int
|
|
|
|
|
|
class Float(Typed):
    """Typed validator for ``float``; plain ``int`` values are rejected."""

    expected_type = float
|
|
|
|
|
|
class String(Typed):
    """Typed validator for ``str`` instances."""

    expected_type = str
|
|
|
|
|
|
class Positive(Validator):
    """Validator rejecting values below zero (zero itself is accepted)."""

    @classmethod
    def check(cls, value):
        """Raise ValueError when ``value < 0``; otherwise continue the chain."""
        is_negative = value < 0
        if is_negative:
            raise ValueError("Expected >= 0")
        return super().check(value)
|
|
|
|
|
|
class NonEmpty(Validator):
    """Validator rejecting values whose ``len()`` is zero."""

    @classmethod
    def check(cls, value):
        """Raise for empty *value*; otherwise continue the check chain.

        Raises:
            ValueError: if ``len(value) == 0``.
            TypeError: propagated from ``len`` when *value* has no length.
        """
        if len(value) == 0:
            # Must raise, not return: a returned exception object is just a
            # value, and the empty input would silently pass validation.
            raise ValueError("Must be non-empty")
        return super().check(value)
|
|
|
|
|
|
class PositiveInteger(Integer, Positive):
    """Composite validator: an ``int`` that is not less than zero.

    Through the MRO, Integer's type check runs before Positive's range check.
    """
|
|
|
|
|
|
class PositiveFloat(Float, Positive):
    """Composite validator: a ``float`` that is not less than zero.

    Through the MRO, Float's type check runs before Positive's range check.
    """
|
|
|
|
|
|
class NonEmptyString(String, NonEmpty):
    """Composite validator combining String and NonEmpty.

    Through the MRO, String's type check runs first, then NonEmpty's check.
    """
|
|
|
|
|
|
def validated(func):
    """Decorate *func* so annotated arguments and the return value are checked.

    Every parameter annotation (and the ``return`` annotation, if present)
    is used as a validator: an object exposing ``check(value)``, such as a
    ``Validator`` subclass. All argument failures are collected and reported
    together before *func* runs.

    Raises (from the wrapper):
        TypeError: ``'Bad Arguments'`` followed by one line per failing
            argument, or ``'Bad return: ...'`` when the result fails its check.
    """
    sig = inspect.signature(func)
    annotations = dict(func.__annotations__)
    retcheck = annotations.pop('return', None)

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        # Include parameters left at their defaults; otherwise they are absent
        # from bound.arguments and the lookup below raises a KeyError that is
        # misreported as a bad argument.
        bound.apply_defaults()
        errors = []
        for name, validator in annotations.items():
            try:
                validator.check(bound.arguments[name])
            except Exception as e:
                # Deliberately broad: gather every failure into one report.
                errors.append(f' {name}: {e}')
        if errors:
            raise TypeError('Bad Arguments\n' + '\n'.join(errors))
        result = func(*args, **kwargs)
        if retcheck is not None:
            try:
                retcheck.check(result)
            except Exception as e:
                raise TypeError(f'Bad return: {e}') from None
        return result

    return wrapper
|
|
|
|
|
|
def enforce(**outerkwargs):
    """Build a decorator that checks arguments and the result with validators.

    Keyword arguments map parameter names to validator objects exposing
    ``check(value)``; the special key ``return_`` supplies the validator for
    the return value. The returned decorator may be applied to any number of
    functions.

    Raises (from the decorated function):
        TypeError: ``'Bad Arguments'`` followed by one line per failing
            argument, or ``'Bad return: ...'`` when the result fails its check.
    """
    # Split a private copy once. Popping from outerkwargs inside `enforced`
    # mutated it, so a second use of the same decorator lost the return check.
    annotations = dict(outerkwargs)
    retcheck = annotations.pop('return_', None)

    def enforced(func):
        sig = inspect.signature(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            # Include defaulted parameters so they are validated rather than
            # surfacing as a KeyError misreported as a bad argument.
            bound.apply_defaults()
            errors = []
            for name, validator in annotations.items():
                try:
                    validator.check(bound.arguments[name])
                except Exception as e:
                    # Deliberately broad: gather every failure into one report.
                    errors.append(f' {name}: {e}')
            if errors:
                raise TypeError('Bad Arguments\n' + '\n'.join(errors))
            result = func(*args, **kwargs)
            if retcheck is not None:
                try:
                    retcheck.check(result)
                except Exception as e:
                    raise TypeError(f'Bad return: {e}') from None
            return result

        return wrapper

    return enforced
|
|
|
|
|
|
class ValidatedFunction:
    """Callable wrapper that checks ``Validator``-annotated arguments.

    Only parameters whose annotation is a class deriving from ``Validator``
    are checked; any other annotation (plain types, strings, generic
    aliases) is ignored. The return value is not checked. The first failing
    ``check`` raises, and its exception propagates unchanged.
    """

    def __init__(self, func):
        self.func = func
        # The signature is fixed, so compute it once instead of on every call.
        self._signature = inspect.signature(func)

    def __get__(self, instance, owner=None):
        """Bind to *instance* so the wrapper also works on methods."""
        if instance is None:
            return self
        import types
        return types.MethodType(self, instance)

    @staticmethod
    def _is_validator(annotation):
        """Return True if *annotation* is a Validator subclass."""
        # Guard first: issubclass() raises TypeError for non-class objects
        # such as string (forward-reference) annotations.
        if not isinstance(annotation, type):
            return False
        try:
            return issubclass(annotation, Validator)
        except TypeError:
            # Some generic aliases pass the isinstance guard on older Pythons.
            return False

    def __call__(self, *args, **kwargs):
        bound = self._signature.bind(*args, **kwargs)
        annotations = getattr(self.func, '__annotations__', None) or {}
        for name, value in bound.arguments.items():
            if name in annotations and self._is_validator(annotations[name]):
                annotations[name].check(value)
        return self.func(*args, **kwargs)
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo definitions only; nothing here is invoked when run as a script.

    @validated
    def add(x: Integer, y: Integer):
        """Return the sum of two Integer-validated arguments."""
        return x + y

    @validated
    def power(x: Integer, y: Integer):
        """Return x raised to the power y, both Integer-validated."""
        return x ** y

    @enforce(x=Integer, y=Integer, return_=Integer)
    def mult(x, y):
        """Return x * y; arguments and result are Integer-validated."""
        return x * y
|