diff --git a/stock.py b/stock.py index c30e3a2..b97a987 100644 --- a/stock.py +++ b/stock.py @@ -4,6 +4,9 @@ from structure import Structure class Stock(Structure): _fields = ('name', 'shares', 'price') + def __init__(self, name, shares, price): + self._init() + @property def cost(self): return self.shares * self.price diff --git a/structure.py b/structure.py index c3a89ed..8ece04c 100644 --- a/structure.py +++ b/structure.py @@ -1,11 +1,15 @@ +import sys + + class Structure: _fields = () - def __init__(self, *args): - if len(args) != len(self._fields): - raise TypeError(f'Expected {len(self._fields)} arguments') - for pos, field in enumerate(self._fields): - setattr(self, field, args[pos]) + @staticmethod + def _init(): + locs = sys._getframe(1).f_locals + self = locs.pop('self') + for name, val in locs.items(): + setattr(self, name, val) def __repr__(self): args = map(lambda field: f"{field}={getattr(self, field)!r}", self._fields)