diff --git a/python/src/writeYourProgram.py b/python/src/writeYourProgram.py index 6a0ce557..1f279943 100644 --- a/python/src/writeYourProgram.py +++ b/python/src/writeYourProgram.py @@ -33,6 +33,13 @@ def _debug(s): def _patchDataClass(cls, mutable): fieldNames = [f.name for f in dataclasses.fields(cls)] setattr(cls, EQ_ATTRS_ATTR, fieldNames) + + if hasattr(cls, '__annotations__'): + # add annotions for type checked constructor. + cls.__init__.__annotations__ = cls.__annotations__ + cls.__init__.__original = cls # mark class as source of annotation + cls.__init__ = untypy.typechecked(cls.__init__) + if mutable: # prevent new fields being added fields = set(fieldNames)