diff --git a/validate.py b/validate.py index be8a788..986d678 100644 --- a/validate.py +++ b/validate.py @@ -1,3 +1,6 @@ +import inspect + + class Validator: def __init__(self, name=None): self.name = name @@ -63,37 +66,23 @@ class NonEmptyString(String, NonEmpty): pass -class Stock: - name = String() - shares = PositiveInteger() - price = PositiveFloat() +class ValidatedFunction: + def __init__(self, func): + self.func = func - def __init__(self, name, shares, price): - self.name = name - self.shares = shares - self.price = price + def __call__(self, *args, **kwargs): + bound = inspect.signature(self.func).bind(*args, **kwargs) + if hasattr(self.func, '__annotations__'): + for arg in bound.arguments: + if arg in self.func.__annotations__ and issubclass( + self.func.__annotations__[arg], Validator + ): + self.func.__annotations__[arg].check(bound.arguments[arg]) + result = self.func(*args, **kwargs) + return result - def __repr__(self): - return f"Stock({self.name!r}, {self.shares!r}, {self.price!r})" - def __eq__(self, other): - if not isinstance(other, Stock): - return False - return (self.name, self.shares, self.price) == ( - other.name, - other.shares, - other.price, - ) - - @classmethod - def from_row(cls, row): - return cls(*row) - - @property - def cost(self): - return self.shares * self.price - - def sell(self, num): - self.shares -= num - if self.shares < 0: - self.shares = 0 +if __name__ == '__main__': + def add(x: Integer, y: Integer): + return x + y + add = ValidatedFunction(add)