python-mastery/validate.py
2024-01-07 15:16:28 -06:00

152 lines
3.8 KiB
Python

import inspect
from functools import wraps
class Validator:
    """Base descriptor that validates values as they are assigned.

    Subclasses override the ``check`` classmethod. The subclasses in this
    module call ``super().check(value)`` after their own test, so the checks
    of several bases run in method-resolution order. The value returned by
    ``check`` is stored in the owning instance's ``__dict__`` under the
    descriptor's name. Because no ``__get__`` is defined, attribute reads
    fall through to that instance dictionary.
    """

    def __init__(self, name=None):
        # Replaced by __set_name__ when the descriptor is declared in a
        # class body.
        self.name = name

    def __set_name__(self, cls, name):
        """Remember the attribute name this descriptor was assigned to."""
        self.name = name

    @classmethod
    def check(cls, value):
        """Accept anything: return *value* unchanged."""
        return value

    def __set__(self, instance, value):
        """Run ``check`` on *value* and store the result on *instance*."""
        checked = self.check(value)
        instance.__dict__[self.name] = checked
class Typed(Validator):
    """Validator requiring ``isinstance(value, expected_type)``.

    Subclasses override ``expected_type``; the default ``object`` accepts
    any value. A value that passes is handed to the next ``check`` in the
    MRO.

    Raises:
        TypeError: if *value* is not an instance of ``expected_type``.
    """

    expected_type = object

    @classmethod
    def check(cls, value):
        wanted = cls.expected_type
        if isinstance(value, wanted):
            return super().check(value)
        raise TypeError(f"expected {wanted}")
class Integer(Typed):
    """Accept instances of ``int`` (``isinstance`` also admits ``bool``)."""

    expected_type = int


class Float(Typed):
    """Accept instances of ``float``; plain ``int`` values are rejected."""

    expected_type = float


class String(Typed):
    """Accept instances of ``str``."""

    expected_type = str
class Positive(Validator):
    """Validator rejecting values below zero (zero itself is accepted).

    Raises:
        ValueError: if ``value < 0``.
    """

    @classmethod
    def check(cls, value):
        if not value < 0:
            return super().check(value)
        raise ValueError("Expected >= 0")
class NonEmpty(Validator):
    """Validator rejecting values whose ``len()`` is zero.

    Raises:
        ValueError: if ``len(value) == 0``.
    """

    @classmethod
    def check(cls, value):
        if len(value) == 0:
            # Must be *raised*: returning the exception object would make an
            # empty value pass validation, and the ValueError instance itself
            # would be stored as the attribute's value.
            raise ValueError("Must be non-empty")
        return super().check(value)
class PositiveInteger(Integer, Positive):
    """Integer's type check followed (via the MRO) by Positive's sign check."""


class PositiveFloat(Float, Positive):
    """Float's type check followed (via the MRO) by Positive's sign check."""


class NonEmptyString(String, NonEmpty):
    """String's type check followed (via the MRO) by NonEmpty's length check."""
def validated(func):
    """Decorate *func* so that its annotated arguments and result are checked.

    Every parameter annotation must be an object with a ``check(value)``
    method (such as the validator classes in this module). All argument
    failures are collected and reported together. A ``return`` annotation,
    if present, is applied to the result.

    Parameters omitted by the caller are filled in from their defaults
    before checking. Without that step, the missing entry in
    ``bound.arguments`` raised a KeyError that was misreported as a bad
    argument.

    Raises:
        TypeError: if the call does not bind to *func*'s signature, if any
            annotated argument fails its check, or if the result fails the
            return check.
    """
    sig = inspect.signature(func)
    annotations = dict(func.__annotations__)
    retcheck = annotations.pop('return', None)

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        errors = []
        for name, validator in annotations.items():
            try:
                validator.check(bound.arguments[name])
            except Exception as e:
                errors.append(f' {name}: {e}')
        if errors:
            raise TypeError('Bad Arguments\n' + '\n'.join(errors))
        result = func(*args, **kwargs)
        if retcheck is not None:
            try:
                retcheck.check(result)
            except Exception as e:
                # Suppress the validator's traceback; the message is enough.
                raise TypeError(f'Bad return: {e}') from None
        return result

    return wrapper
def enforce(**outerkwargs):
    """Return a decorator that checks the arguments named in *outerkwargs*.

    Each keyword maps a parameter name to an object with a ``check(value)``
    method. The special keyword ``return_`` supplies the check for the
    function's result. Argument failures are collected and reported
    together, as in ``validated``.

    Raises (from the decorated function):
        TypeError: if the call does not bind to the signature, if any
            named argument fails its check, or if the result fails the
            ``return_`` check.
    """
    # Split off the return check once, here. Doing the pop inside the inner
    # decorator mutated the shared kwargs, so a decorator object applied to
    # a second function silently lost its return check.
    retcheck = outerkwargs.pop('return_', None)
    annotations = outerkwargs

    def enforced(func):
        sig = inspect.signature(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            # Include defaulted parameters so they are checked rather than
            # producing a KeyError reported as a bad argument.
            bound.apply_defaults()
            errors = []
            for name, validator in annotations.items():
                try:
                    validator.check(bound.arguments[name])
                except Exception as e:
                    errors.append(f' {name}: {e}')
            if errors:
                raise TypeError('Bad Arguments\n' + '\n'.join(errors))
            result = func(*args, **kwargs)
            if retcheck is not None:
                try:
                    retcheck.check(result)
                except Exception as e:
                    raise TypeError(f'Bad return: {e}') from None
            return result

        return wrapper

    return enforced
class ValidatedFunction:
    """Callable wrapper that checks Validator-annotated arguments of *func*.

    Only parameters whose annotation is a ``Validator`` subclass are
    checked. Other annotations (plain types, string annotations, typing
    constructs such as ``int | None``) are ignored. The return annotation is
    not checked. Exceptions raised by a validator's ``check`` propagate
    unchanged.
    """

    def __init__(self, func):
        self.func = func
        # The wrapped function's signature is fixed, so compute it once
        # instead of on every call.
        self.signature = inspect.signature(func)

    @staticmethod
    def _is_validator(annotation):
        """Return True if *annotation* is a Validator subclass."""
        # issubclass() raises TypeError when its first argument is not a
        # class (e.g. a string annotation); such annotations are simply
        # not validators.
        try:
            return issubclass(annotation, Validator)
        except TypeError:
            return False

    def __call__(self, *args, **kwargs):
        bound = self.signature.bind(*args, **kwargs)
        annotations = getattr(self.func, '__annotations__', {})
        for arg, value in bound.arguments.items():
            if arg in annotations and self._is_validator(annotations[arg]):
                annotations[arg].check(value)
        return self.func(*args, **kwargs)
if __name__ == '__main__':
    # Demonstration of the two decorator styles: keyword-driven (@enforce,
    # which also checks the result through ``return_``) and
    # annotation-driven (@validated).

    @enforce(x=Integer, y=Integer, return_=Integer)
    def mult(x, y):
        return x * y

    @validated
    def add(x: Integer, y: Integer):
        return x + y

    @validated
    def power(x: Integer, y: Integer):
        return pow(x, y)