diff --git a/sdks/python/apache_beam/ml/anomaly/base.py b/sdks/python/apache_beam/ml/anomaly/base.py index e3c6252474bd..8251245f1cac 100644 --- a/sdks/python/apache_beam/ml/anomaly/base.py +++ b/sdks/python/apache_beam/ml/anomaly/base.py @@ -154,7 +154,7 @@ def __init__( threshold_criterion: Optional[ThresholdFn] = None, **kwargs): self._model_id = model_id if model_id is not None else getattr( - self, 'spec_type', 'unknown') + self, 'spec_type', lambda: "unknown")() self._features = features self._target = target self._threshold_criterion = threshold_criterion @@ -200,7 +200,7 @@ def __init__( aggregation_strategy: Optional[AggregationFn] = None, **kwargs): if "model_id" not in kwargs or kwargs["model_id"] is None: - kwargs["model_id"] = getattr(self, 'spec_type', 'custom') + kwargs["model_id"] = getattr(self, 'spec_type', lambda: 'custom')() super().__init__(**kwargs) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index 2eeb1d0de76d..3a2baf434f9b 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -27,7 +27,7 @@ import logging import os from typing import Any -from typing import ClassVar +from typing import Callable from typing import Dict from typing import List from typing import Optional @@ -35,6 +35,7 @@ from typing import Type from typing import TypeVar from typing import Union +from typing import overload from typing import runtime_checkable from typing_extensions import Self @@ -59,7 +60,8 @@ #: `spec_type` when applying the `specifiable` decorator to an existing class. _KNOWN_SPECIFIABLE = collections.defaultdict(dict) -SpecT = TypeVar('SpecT', bound='Specifiable') +T = TypeVar('T', bound=type) +BUILTIN_TYPES_IN_SPEC = (int, float, complex, str, bytes, bytearray) def _class_to_subspace(cls: Type) -> str: @@ -104,33 +106,59 @@ class Spec(): config: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict) -@runtime_checkable -class Specifiable(Protocol): - """Protocol that a specifiable class needs to implement.""" - #: The value of the `type` field in the object's spec for this class. - spec_type: ClassVar[str] - #: The raw keyword arguments passed to `__init__` method during object - #: initialization. - init_kwargs: dict[str, Any] +def _specifiable_from_spec_helper(v, _run_init): + if isinstance(v, Spec): + return Specifiable.from_spec(v, _run_init) + + if isinstance(v, List): + return [_specifiable_from_spec_helper(e, _run_init) for e in v] + + # TODO: support spec treatment for more types + if not isinstance(v, BUILTIN_TYPES_IN_SPEC): + logging.warning( + "Type %s is not a recognized supported type for the" + "specification. It will be included without conversion.", + str(type(v))) + return v - # a boolean to tell whether the original `__init__` method is called - _initialized: bool - # a boolean used by new_getattr to tell whether it is in the `__init__` method - # call - _in_init: bool - @staticmethod - def _from_spec_helper(v, _run_init): - if isinstance(v, Spec): - return Specifiable.from_spec(v, _run_init) +def _specifiable_to_spec_helper(v): + if isinstance(v, Specifiable): + return v.to_spec() - if isinstance(v, List): - return [Specifiable._from_spec_helper(e, _run_init) for e in v] + if isinstance(v, List): + return [_specifiable_to_spec_helper(e) for e in v] - return v + if inspect.isfunction(v): + if not hasattr(v, "spec_type"): + _register(v, inject_spec_type=False) + return Spec(type=_get_default_spec_type(v), config=None) + if inspect.isclass(v): + if not hasattr(v, "spec_type"): + _register(v, inject_spec_type=False) + return Spec(type=_get_default_spec_type(v), config=None) + + # TODO: support spec treatment for more types + if not isinstance(v, BUILTIN_TYPES_IN_SPEC): + logging.warning( + "Type %s is not a recognized supported type for the" + "specification. It will be included without conversion.", + str(type(v))) + return v + + +@runtime_checkable +class Specifiable(Protocol): + """Protocol that a specifiable class needs to implement.""" @classmethod - def from_spec(cls, spec: Spec, _run_init: bool = True) -> Union[Self, type]: + def spec_type(cls) -> str: + pass + + @classmethod + def from_spec(cls, + spec: Spec, + _run_init: bool = True) -> Union[Self, type[Self]]: """Generate a `Specifiable` subclass object based on a spec. Args: @@ -155,7 +183,7 @@ def from_spec(cls, spec: Spec, _run_init: bool = True) -> Union[Self, type]: return subclass kwargs = { - k: Specifiable._from_spec_helper(v, _run_init) + k: _specifiable_from_spec_helper(v, _run_init) for k, v in spec.config.items() } @@ -164,26 +192,6 @@ def from_spec(cls, spec: Spec, _run_init: bool = True) -> Union[Self, type]: kwargs["_run_init"] = True return subclass(**kwargs) - @staticmethod - def _to_spec_helper(v): - if isinstance(v, Specifiable): - return v.to_spec() - - if isinstance(v, List): - return [Specifiable._to_spec_helper(e) for e in v] - - if inspect.isfunction(v): - if not hasattr(v, "spec_type"): - _register(v, inject_spec_type=False) - return Spec(type=_get_default_spec_type(v), config=None) - - if inspect.isclass(v): - if not hasattr(v, "spec_type"): - _register(v, inject_spec_type=False) - return Spec(type=_get_default_spec_type(v), config=None) - - return v - def to_spec(self) -> Spec: """Generate a spec from a `Specifiable` subclass object. @@ -195,14 +203,22 @@ def to_spec(self) -> Spec: f"'{type(self).__name__}' not registered as Specifiable. " f"Decorate ({type(self).__name__}) with @specifiable") - args = {k: self._to_spec_helper(v) for k, v in self.init_kwargs.items()} + args = { + k: _specifiable_to_spec_helper(v) + for k, v in self.init_kwargs.items() + } - return Spec(type=self.__class__.spec_type, config=args) + return Spec(type=self.spec_type(), config=args) def run_original_init(self) -> None: """Invoke the original __init__ method with original keyword arguments""" pass + @classmethod + def unspecifiable(cls) -> None: + """Resume the class structure prior to specifiable""" + pass + def _get_default_spec_type(cls): spec_type = cls.__name__ @@ -216,7 +232,7 @@ def _get_default_spec_type(cls): # Register a `Specifiable` subclass in `KNOWN_SPECIFIABLE` -def _register(cls, spec_type=None, inject_spec_type=True) -> None: +def _register(cls: type, spec_type=None, inject_spec_type=True) -> None: assert spec_type is None or inject_spec_type, \ "need to inject spec_type to class if spec_type is not None" if spec_type is None: @@ -237,7 +253,8 @@ def _register(cls, spec_type=None, inject_spec_type=True) -> None: _KNOWN_SPECIFIABLE[subspace][spec_type] = cls if inject_spec_type: - cls.spec_type = spec_type + setattr(cls, cls.__name__ + '__spec_type', spec_type) + # cls.__spec_type = spec_type # Keep a copy of arguments that are used to call the `__init__` method when the @@ -250,13 +267,35 @@ def _get_init_kwargs(inst, init_method, *args, **kwargs): return params +@overload +def specifiable( + my_cls: None = None, + /, + *, + spec_type: Optional[str] = None, + on_demand_init: bool = True, + just_in_time_init: bool = True) -> Callable[[T], T]: + pass + + +@overload def specifiable( - my_cls=None, + my_cls: T, /, *, - spec_type=None, - on_demand_init=True, - just_in_time_init=True): + spec_type: Optional[str] = None, + on_demand_init: bool = True, + just_in_time_init: bool = True) -> T: + pass + + +def specifiable( + my_cls: Optional[T] = None, + /, + *, + spec_type: Optional[str] = None, + on_demand_init: bool = True, + just_in_time_init: bool = True) -> Union[T, Callable[[T], T]]: """A decorator that turns a class into a `Specifiable` subclass by implementing the `Specifiable` protocol. @@ -285,8 +324,8 @@ class Bar(): original `__init__` method will be called when the first time an attribute is accessed. """ - def _wrapper(cls): - def new_init(self: Specifiable, *args, **kwargs): + def _wrapper(cls: T) -> T: + def new_init(self, *args, **kwargs): self._initialized = False self._in_init = False @@ -358,20 +397,40 @@ def new_getattr(self, name): name) return self.__getattribute__(name) + def spec_type_func(cls): + return getattr(cls, spec_type_attr_name) + + def unspecifiable(cls): + delattr(cls, spec_type_attr_name) + cls.__init__ = original_init + if just_in_time_init: + delattr(cls, '__getattr__') + delattr(cls, 'spec_type') + delattr(cls, 'run_original_init') + delattr(cls, 'to_spec') + delattr(cls, 'from_spec') + delattr(cls, 'unspecifiable') + + spec_type_attr_name = cls.__name__ + "__spec_type" + + # the class is registered + if hasattr(cls, spec_type_attr_name): + return cls + # start of the function body of _wrapper _register(cls, spec_type) class_name = cls.__name__ - original_init = cls.__init__ - cls.__init__ = new_init + original_init = cls.__init__ # type: ignore[misc] + cls.__init__ = new_init # type: ignore[misc] if just_in_time_init: cls.__getattr__ = new_getattr + cls.spec_type = classmethod(spec_type_func) cls.run_original_init = run_original_init cls.to_spec = Specifiable.to_spec - cls._to_spec_helper = staticmethod(Specifiable._to_spec_helper) cls.from_spec = Specifiable.from_spec - cls._from_spec_helper = staticmethod(Specifiable._from_spec_helper) + cls.unspecifiable = classmethod(unspecifiable) return cls # end of the function body of _wrapper diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py index 4c1a7bdaf32a..4492cbbe4104 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py @@ -52,7 +52,7 @@ class B(): # apply the decorator function to an existing class A = specifiable(A) - self.assertEqual(A.spec_type, "A") + self.assertEqual(A.spec_type(), "A") self.assertTrue(isinstance(A(), Specifiable)) self.assertIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A"], A) @@ -63,13 +63,10 @@ class B(): # Raise an error when re-registering spec_type with a different class self.assertRaises(ValueError, specifiable(spec_type='A'), B) - # apply the decorator function to an existing class with a different - # spec_type + # Applying the decorator function to an existing class with a different + # spec_type will have no effect. A = specifiable(spec_type="A_DUP")(A) - self.assertEqual(A.spec_type, "A_DUP") - self.assertTrue(isinstance(A(), Specifiable)) - self.assertIn("A_DUP", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) - self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A_DUP"], A) + self.assertEqual(A.spec_type(), "A") def test_decorator_in_syntactic_sugar_form(self): # call decorator without parameters @@ -585,6 +582,49 @@ def apply(self, x, y): self.assertEqual(w_2.run_func_in_class(5, 3), 150) +class TestUncommonUsages(unittest.TestCase): + def test_double_specifiable(self): + @specifiable + @specifiable + class ZZ(): + def __init__(self, a): + self.a = a + + assert issubclass(ZZ, Specifiable) + c = ZZ("b") + c.run_original_init() + self.assertEqual(c.a, "b") + + def test_unspecifiable(self): + class YY(): + def __init__(self, x): + self.x = x + assert False + + YY = specifiable(YY) + assert issubclass(YY, Specifiable) + y = YY(1) + # __init__ is called (with assertion error raised) when attribute is first + # accessed + self.assertRaises(AssertionError, lambda: y.x) + + # unspecifiable YY + YY.unspecifiable() + # __init__ is called immediately + self.assertRaises(AssertionError, YY, 1) + self.assertFalse(hasattr(YY, 'run_original_init')) + self.assertFalse(hasattr(YY, 'spec_type')) + self.assertFalse(hasattr(YY, 'to_spec')) + self.assertFalse(hasattr(YY, 'from_spec')) + self.assertFalse(hasattr(YY, 'unspecifiable')) + + # make YY specifiable again + YY = specifiable(YY) + assert issubclass(YY, Specifiable) + y = YY(1) + self.assertRaises(AssertionError, lambda: y.x) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main()