From 18f96560bdff8e8aaa7cd6e45a586222cc6d29c8 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Mon, 3 Feb 2025 23:05:17 -0500 Subject: [PATCH 01/11] Add base classes and specifiable protocol for anomaly detection. --- sdks/python/apache_beam/ml/anomaly/base.py | 114 +++++ .../apache_beam/ml/anomaly/base_test.py | 209 +++++++++ .../apache_beam/ml/anomaly/specifiable.py | 196 ++++++++ .../ml/anomaly/specifiable_test.py | 431 ++++++++++++++++++ 4 files changed, 950 insertions(+) create mode 100644 sdks/python/apache_beam/ml/anomaly/base.py create mode 100644 sdks/python/apache_beam/ml/anomaly/base_test.py create mode 100644 sdks/python/apache_beam/ml/anomaly/specifiable.py create mode 100644 sdks/python/apache_beam/ml/anomaly/specifiable_test.py diff --git a/sdks/python/apache_beam/ml/anomaly/base.py b/sdks/python/apache_beam/ml/anomaly/base.py new file mode 100644 index 000000000000..fdbb88548d55 --- /dev/null +++ b/sdks/python/apache_beam/ml/anomaly/base.py @@ -0,0 +1,114 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Base classes for anomaly detection""" +from __future__ import annotations + +import abc +from dataclasses import dataclass +from typing import Iterable +from typing import List +from typing import Optional + +import apache_beam as beam + + +@dataclass(frozen=True) +class AnomalyPrediction(): + model_id: Optional[str] = None + score: Optional[float] = None + label: Optional[int] = None + threshold: Optional[float] = None + info: str = "" + agg_history: Optional[Iterable[AnomalyPrediction]] = None + + +@dataclass(frozen=True) +class AnomalyResult(): + example: beam.Row + prediction: AnomalyPrediction + + +class ThresholdFn(abc.ABC): + def __init__(self, normal_label: int = 0, outlier_label: int = 1): + self._normal_label = normal_label + self._outlier_label = outlier_label + + @property + @abc.abstractmethod + def is_stateful(self) -> bool: + raise NotImplementedError + + @property + @abc.abstractmethod + def threshold(self) -> Optional[float]: + raise NotImplementedError + + @abc.abstractmethod + def apply(self, score: Optional[float]) -> int: + raise NotImplementedError + + +class AggregationFn(abc.ABC): + @abc.abstractmethod + def apply( + self, predictions: Iterable[AnomalyPrediction]) -> AnomalyPrediction: + raise NotImplementedError + + +class AnomalyDetector(abc.ABC): + def __init__( + self, + model_id: Optional[str] = None, + features: Optional[Iterable[str]] = None, + target: Optional[str] = None, + threshold_criterion: Optional[ThresholdFn] = None, + **kwargs): + self._model_id = model_id if model_id is not None else getattr( + self, '_key', 'unknown') + self._features = features + self._target = target + self._threshold_criterion = threshold_criterion + + @abc.abstractmethod + def learn_one(self, x: beam.Row) -> None: + raise NotImplementedError + + @abc.abstractmethod + def score_one(self, x: beam.Row) -> float: + raise NotImplementedError + + +class EnsembleAnomalyDetector(AnomalyDetector): + def __init__( + self, + sub_detectors: Optional[List[AnomalyDetector]] = None, + aggregation_strategy: Optional[AggregationFn] = None, + **kwargs): + if "model_id" not in kwargs or kwargs["model_id"] is None: + kwargs["model_id"] = getattr(self, '_key', 'custom') + + super().__init__(**kwargs) + + self._aggregation_strategy = aggregation_strategy + self._sub_detectors = sub_detectors + + def learn_one(self, x: beam.Row) -> None: + raise NotImplementedError + + def score_one(self, x: beam.Row) -> float: + raise NotImplementedError diff --git a/sdks/python/apache_beam/ml/anomaly/base_test.py b/sdks/python/apache_beam/ml/anomaly/base_test.py new file mode 100644 index 000000000000..1894601624ae --- /dev/null +++ b/sdks/python/apache_beam/ml/anomaly/base_test.py @@ -0,0 +1,209 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import logging +import unittest + +from apache_beam.ml.anomaly.base import AnomalyDetector +from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector +from apache_beam.ml.anomaly.base import ThresholdFn +from apache_beam.ml.anomaly.base import AggregationFn +from apache_beam.ml.anomaly.specifiable import Spec +from apache_beam.ml.anomaly.specifiable import Specifiable +from apache_beam.ml.anomaly.specifiable import specifiable + + +class TestAnomalyDetector(unittest.TestCase): + @specifiable(on_demand_init=False) + class DummyThreshold(ThresholdFn): + def __init__(self, my_threshold_arg=None): + ... + + def is_stateful(self): + return False + + def threshold(self): + ... + + def apply(self, x): + ... + + @specifiable(on_demand_init=False) + class Dummy(AnomalyDetector): + def __init__(self, my_arg=None, **kwargs): + self._my_arg = my_arg + super().__init__(**kwargs) + + def learn_one(self): + ... + + def score_one(self): + ... + + def __eq__(self, value: "TestAnomalyDetector.Dummy") -> bool: + return self._my_arg == value._my_arg + + def test_unknown_detector(self): + self.assertRaises(ValueError, Specifiable.from_spec, Spec(type="unknown")) + + def test_model_id_on_known_detector(self): + a = self.Dummy( + my_arg="abc", + target="ABC", + threshold_criterion=(t1 := self.DummyThreshold(2))) + + self.assertEqual(a._model_id, "Dummy") + self.assertEqual(a._target, "ABC") + self.assertEqual(a._my_arg, "abc") + + assert isinstance(a, Specifiable) + self.assertEqual( + a._init_params, { + "my_arg": "abc", + "target": "ABC", + "threshold_criterion": t1, + }) + + b = self.Dummy( + my_arg="efg", + model_id="my_dummy", + target="EFG", + threshold_criterion=(t2 := self.DummyThreshold(2))) + self.assertEqual(b._model_id, "my_dummy") + self.assertEqual(b._target, "EFG") + self.assertEqual(b._my_arg, "efg") + + assert isinstance(b, Specifiable) + self.assertEqual( + b._init_params, + { + "model_id": "my_dummy", + "my_arg": "efg", + "target": "EFG", + "threshold_criterion": t2, + }) + + def test_from_and_to_specifiable(self): + obj = self.Dummy( + my_arg="hij", + model_id="my_dummy", + target="HIJ", + threshold_criterion=self.DummyThreshold(4)) + + assert isinstance(obj, Specifiable) + spec = obj.to_spec() + expected_spec = Spec( + type="Dummy", + config={ + "my_arg": "hij", + "model_id": "my_dummy", + "target": "HIJ", + "threshold_criterion": Spec( + type="DummyThreshold", config={"my_threshold_arg": 4}), + }) + self.assertEqual(spec, expected_spec) + + new_obj = Specifiable.from_spec(spec) + self.assertEqual(obj, new_obj) + + +class TestEnsembleAnomalyDetector(unittest.TestCase): + @specifiable(on_demand_init=False) + class DummyAggregation(AggregationFn): + def apply(self, x): + ... + + @specifiable(on_demand_init=False) + class DummyEnsemble(EnsembleAnomalyDetector): + def __init__(self, my_ensemble_arg=None, **kwargs): + super().__init__(**kwargs) + self._my_ensemble_arg = my_ensemble_arg + + def learn_one(self): + ... + + def score_one(self): + ... + + def __eq__( + self, value: 'TestEnsembleAnomalyDetector.DummyEnsemble') -> bool: + return self._my_ensemble_arg == value._my_ensemble_arg + + @specifiable(on_demand_init=False) + class DummyWeakLearner(AnomalyDetector): + def __init__(self, my_arg=None, **kwargs): + super().__init__(**kwargs) + self._my_arg = my_arg + + def learn_one(self): + ... + + def score_one(self): + ... + + def __eq__( + self, value: 'TestEnsembleAnomalyDetector.DummyWeakLearner') -> bool: + return self._my_arg == value._my_arg + + def test_model_id_on_known_detector(self): + a = self.DummyEnsemble() + self.assertEqual(a._model_id, "DummyEnsemble") + + b = self.DummyEnsemble(model_id="my_dummy_ensemble") + self.assertEqual(b._model_id, "my_dummy_ensemble") + + c = EnsembleAnomalyDetector() + self.assertEqual(c._model_id, "custom") + + d = EnsembleAnomalyDetector(model_id="my_dummy_ensemble_2") + self.assertEqual(d._model_id, "my_dummy_ensemble_2") + + def test_from_and_to_specifiable(self): + d1 = self.DummyWeakLearner(my_arg=1) + d2 = self.DummyWeakLearner(my_arg=2) + ensemble = self.DummyEnsemble( + my_ensemble_arg=123, + learners=[d1, d2], + aggregation_strategy=self.DummyAggregation()) + + expected_spec = Spec( + type="DummyEnsemble", + config={ + "my_ensemble_arg": 123, + "learners": [ + Spec(type="DummyWeakLearner", config={"my_arg": 1}), + Spec(type="DummyWeakLearner", config={"my_arg": 2}) + ], + "aggregation_strategy": Spec( + type="DummyAggregation", + config={}, + ), + }) + + assert isinstance(ensemble, Specifiable) + spec = ensemble.to_spec() + self.assertEqual(spec, expected_spec) + + new_ensemble = Specifiable.from_spec(spec) + self.assertEqual(ensemble, new_ensemble) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py new file mode 100644 index 000000000000..478ca93f941a --- /dev/null +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -0,0 +1,196 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import dataclasses +import inspect +import logging +from typing import Any +from typing import ClassVar +from typing import List +from typing import Protocol +from typing import TypeVar +from typing import Type +from typing import runtime_checkable + +from typing_extensions import Self + +KNOWN_SPECIFIABLE = {} + +SpecT = TypeVar('SpecT', bound='Specifiable') + + +@dataclasses.dataclass(frozen=True) +class Spec(): + type: str + config: dict[str, Any] = dataclasses.field(default_factory=dict) + + +@runtime_checkable +class Specifiable(Protocol): + _key: ClassVar[str] + _init_params: dict[str, Any] + + @staticmethod + def _from_spec_helper(v): + if isinstance(v, Spec): + return Specifiable.from_spec(v) + + if isinstance(v, List): + return [Specifiable._from_spec_helper(e) for e in v] + + return v + + @classmethod + def from_spec(cls, spec: Spec) -> Self: + if spec.type is None: + raise ValueError(f"Spec type not found in {spec}") + + subclass: Type[Self] = KNOWN_SPECIFIABLE.get(spec.type, None) + if subclass is None: + raise ValueError(f"Unknown spec type '{spec.type}' in {spec}") + + args = {k: Specifiable._from_spec_helper(v) for k, v in spec.config.items()} + + return subclass(**args) + + @staticmethod + def _to_spec_helper(v): + if isinstance(v, Specifiable): + return v.to_spec() + + if isinstance(v, List): + return [Specifiable._to_spec_helper(e) for e in v] + + return v + + def to_spec(self) -> Spec: + if getattr(type(self), '_key', None) is None: + raise ValueError( + f"'{type(self).__name__}' not registered as Specifiable. " + f"Decorate ({type(self).__name__}) with @specifiable") + + args = {k: self._to_spec_helper(v) for k, v in self._init_params.items()} + + return Spec(type=self.__class__._key, config=args) + + +def register(cls, key, error_if_exists) -> None: + if key is None: + key = cls.__name__ + + if key in KNOWN_SPECIFIABLE and error_if_exists: + raise ValueError(f"{key} is already registered for specifiable") + + KNOWN_SPECIFIABLE[key] = cls + + cls._key = key + + +def track_init_params(inst, init_method, *args, **kwargs): + params = dict( + zip(inspect.signature(init_method).parameters.keys(), (None, ) + args)) + del params['self'] + params.update(**kwargs) + inst._init_params = params + + +def specifiable( + my_cls=None, + /, + *, + key=None, + error_if_exists=True, + on_demand_init=True, + just_in_time_init=True): + + # register a specifiable, track init params for each instance, lazy init + def _wrapper(cls): + register(cls, key, error_if_exists) + + original_init = cls.__init__ + class_name = cls.__name__ + + def new_init(self, *args, **kwargs): + self._initialized = False + #self._nested_getattr = False + + if kwargs.get("_run_init", False): + run_init = True + del kwargs['_run_init'] + else: + run_init = False + + if '_init_params' not in self.__dict__: + track_init_params(self, original_init, *args, **kwargs) + + # If it is not a nested specifiable, we choose whether to skip original + # init call based on options. Otherwise, we always call original init + # for inner (parent/grandparent/etc) specifiable. + if (on_demand_init and not run_init) or \ + (not on_demand_init and just_in_time_init): + return + + logging.debug("call original %s.__init__ in new_init", class_name) + original_init(self, *args, **kwargs) + self._initialized = True + + def run_init(self): + original_init(self, **self._init_params) + + def new_getattr(self, name): + if name == '_nested_getattr' or \ + ('_nested_getattr' in self.__dict__ and self._nested_getattr): + #self._nested_getattr = False + delattr(self, "_nested_getattr") + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'") + + # set it before original init, in case getattr is called in original init + self._nested_getattr = True + + if not self._initialized and name != "__getstate__": + logging.debug("call original %s.__init__ in new_getattr", class_name) + original_init(self, **self._init_params) + self._initialized = True + + try: + logging.debug("call original %s.getattr in new_getattr", class_name) + ret = getattr(self, name) + finally: + # self._nested_getattr = False + delattr(self, "_nested_getattr") + return ret + + if just_in_time_init: + cls.__getattr__ = new_getattr + + cls.__init__ = new_init + cls._run_init = run_init + cls.to_spec = Specifiable.to_spec + cls._to_spec_helper = staticmethod(Specifiable._to_spec_helper) + cls.from_spec = classmethod(Specifiable.from_spec) # type: ignore + cls._from_spec_helper = staticmethod(Specifiable._from_spec_helper) + return cls + + if my_cls is None: + # support @specifiable(...) + return _wrapper + + # support @specifiable without arguments + return _wrapper(my_cls) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py new file mode 100644 index 000000000000..d176d1d34239 --- /dev/null +++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py @@ -0,0 +1,431 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import dataclasses +import logging +from parameterized import parameterized +from typing import List +from typing import Optional +import unittest + +from apache_beam.ml.anomaly.specifiable import Spec +from apache_beam.ml.anomaly.specifiable import Specifiable +from apache_beam.ml.anomaly.specifiable import specifiable +from apache_beam.ml.anomaly.specifiable import KNOWN_SPECIFIABLE + + +class TestSpecifiable(unittest.TestCase): + def test_register_specifiable(self): + class MyClass(): + pass + + # class is not decorated/registered + self.assertRaises(AttributeError, lambda: MyClass().to_spec()) # type: ignore + + self.assertNotIn("MyKey", KNOWN_SPECIFIABLE) + + MyClass = specifiable(key="MyKey")(MyClass) + + self.assertIn("MyKey", KNOWN_SPECIFIABLE) + self.assertEqual(KNOWN_SPECIFIABLE["MyKey"], MyClass) + + # By default, an error is raised if the key is duplicated + self.assertRaises(ValueError, specifiable(key="MyKey"), MyClass) + + # But it is ok if a different key is used for the same class + _ = specifiable(key="MyOtherKey")(MyClass) + self.assertIn("MyOtherKey", KNOWN_SPECIFIABLE) + + # Or, use a parameter to suppress the error + specifiable(key="MyKey", error_if_exists=False)(MyClass) + + def test_decorator_key(self): + # use decorator without parameter + @specifiable + class MySecondClass(): + pass + + self.assertIn("MySecondClass", KNOWN_SPECIFIABLE) + self.assertEqual(KNOWN_SPECIFIABLE["MySecondClass"], MySecondClass) + self.assertTrue(isinstance(MySecondClass(), Specifiable)) + + # use decorator with key parameter + @specifiable(key="MyThirdKey") + class MyThirdClass(): + pass + + self.assertIn("MyThirdKey", KNOWN_SPECIFIABLE) + self.assertEqual(KNOWN_SPECIFIABLE["MyThirdKey"], MyThirdClass) + + def test_init_params_in_specifiable(self): + @specifiable + class MyClassWithInitParams(): + def __init__(self, arg_1, arg_2=2, arg_3="3", **kwargs): + pass + + a = MyClassWithInitParams(10, arg_3="30", arg_4=40) + assert isinstance(a, Specifiable) + self.assertEqual(a._init_params, {'arg_1': 10, 'arg_3': '30', 'arg_4': 40}) + + # inheritance of specifiable + @specifiable + class MyDerivedClassWithInitParams(MyClassWithInitParams): + def __init__(self, new_arg_1, new_arg_2=200, new_arg_3="300", **kwargs): + super().__init__(**kwargs) + + b = MyDerivedClassWithInitParams( + 1000, arg_1=11, arg_2=20, new_arg_2=2000, arg_4=4000) + assert isinstance(b, Specifiable) + self.assertEqual( + b._init_params, + { + 'new_arg_1': 1000, + 'arg_1': 11, + 'arg_2': 20, + 'new_arg_2': 2000, + 'arg_4': 4000 + }) + + # composite of specifiable + @specifiable + class MyCompositeClassWithInitParams(): + def __init__(self, my_class: Optional[MyClassWithInitParams] = None): + pass + + c = MyCompositeClassWithInitParams(a) + assert isinstance(c, Specifiable) + self.assertEqual(c._init_params, {'my_class': a}) + + def test_from_and_to_specifiable(self): + @specifiable(on_demand_init=False, just_in_time_init=False) + @dataclasses.dataclass + class Product(): + name: str + price: float + + @specifiable( + key="shopping_entry", on_demand_init=False, just_in_time_init=False) + class Entry(): + def __init__(self, product: Product, quantity: int = 1): + self._product = product + self._quantity = quantity + + def __eq__(self, value: 'Entry') -> bool: + return self._product == value._product and \ + self._quantity == value._quantity + + @specifiable( + key="shopping_cart", on_demand_init=False, just_in_time_init=False) + @dataclasses.dataclass + class ShoppingCart(): + user_id: str + entries: List[Entry] + + orange = Product("orange", 1.0) + + expected_orange_spec = Spec( + "Product", config={ + 'name': 'orange', 'price': 1.0 + }) + assert isinstance(orange, Specifiable) + self.assertEqual(orange.to_spec(), expected_orange_spec) + self.assertEqual(Specifiable.from_spec(expected_orange_spec), orange) + + entry_1 = Entry(product=orange) + + expected_entry_spec_1 = Spec( + "shopping_entry", config={ + 'product': expected_orange_spec, + }) + + assert isinstance(entry_1, Specifiable) + self.assertEqual(entry_1.to_spec(), expected_entry_spec_1) + self.assertEqual(Specifiable.from_spec(expected_entry_spec_1), entry_1) + + banana = Product("banana", 0.5) + expected_banana_spec = Spec( + "Product", config={ + 'name': 'banana', 'price': 0.5 + }) + entry_2 = Entry(product=banana, quantity=5) + expected_entry_spec_2 = Spec( + "shopping_entry", + config={ + 'product': expected_banana_spec, 'quantity': 5 + }) + + shopping_cart = ShoppingCart(user_id="test", entries=[entry_1, entry_2]) + expected_shopping_cart_spec = Spec( + "shopping_cart", + config={ + "user_id": "test", + "entries": [expected_entry_spec_1, expected_entry_spec_2] + }) + + assert isinstance(shopping_cart, Specifiable) + self.assertEqual(shopping_cart.to_spec(), expected_shopping_cart_spec) + self.assertEqual( + Specifiable.from_spec(expected_shopping_cart_spec), shopping_cart) + + def test_on_demand_init(self): + @specifiable(on_demand_init=True, just_in_time_init=False) + class FooOnDemand(): + counter = 0 + + def __init__(self, arg): + self.my_arg = arg * 10 + FooOnDemand.counter += 1 + + foo = FooOnDemand(123) + self.assertEqual(FooOnDemand.counter, 0) + self.assertIn("_init_params", foo.__dict__) + self.assertEqual(foo.__dict__["_init_params"], {"arg": 123}) + + self.assertNotIn("my_arg", foo.__dict__) + self.assertRaises(AttributeError, getattr, foo, "my_arg") + self.assertRaises(AttributeError, lambda: foo.my_arg) + self.assertRaises(AttributeError, getattr, foo, "unknown_arg") + self.assertRaises(AttributeError, lambda: foo.unknown_arg) # type: ignore + self.assertEqual(FooOnDemand.counter, 0) + + foo_2 = FooOnDemand(456, _run_init=True) # type: ignore + self.assertEqual(FooOnDemand.counter, 1) + self.assertIn("_init_params", foo_2.__dict__) + self.assertEqual(foo_2.__dict__["_init_params"], {"arg": 456}) + + self.assertIn("my_arg", foo_2.__dict__) + self.assertEqual(foo_2.my_arg, 4560) + self.assertEqual(FooOnDemand.counter, 1) + + def test_just_in_time_init(self): + @specifiable(on_demand_init=False, just_in_time_init=True) + class FooJustInTime(): + counter = 0 + + def __init__(self, arg): + self.my_arg = arg * 10 + FooJustInTime.counter += 1 + + foo = FooJustInTime(321) + self.assertEqual(FooJustInTime.counter, 0) + self.assertIn("_init_params", foo.__dict__) + self.assertEqual(foo.__dict__["_init_params"], {"arg": 321}) + + self.assertNotIn("my_arg", foo.__dict__) # __init__ hasn't been called + self.assertEqual(FooJustInTime.counter, 0) + + # __init__ is called when trying to accessing an attribute + self.assertEqual(foo.my_arg, 3210) + self.assertEqual(FooJustInTime.counter, 1) + self.assertRaises(AttributeError, lambda: foo.unknown_arg) # type: ignore + self.assertEqual(FooJustInTime.counter, 1) + + def test_on_demand_and_just_in_time_init(self): + @specifiable(on_demand_init=True, just_in_time_init=True) + class FooOnDemandAndJustInTime(): + counter = 0 + + def __init__(self, arg): + self.my_arg = arg * 10 + FooOnDemandAndJustInTime.counter += 1 + + foo = FooOnDemandAndJustInTime(987) + self.assertEqual(FooOnDemandAndJustInTime.counter, 0) + self.assertIn("_init_params", foo.__dict__) + self.assertEqual(foo.__dict__["_init_params"], {"arg": 987}) + self.assertNotIn("my_arg", foo.__dict__) + + self.assertEqual(FooOnDemandAndJustInTime.counter, 0) + # __init__ is called + self.assertEqual(foo.my_arg, 9870) + self.assertEqual(FooOnDemandAndJustInTime.counter, 1) + + # __init__ is called + foo_2 = FooOnDemandAndJustInTime(789, _run_init=True) # type: ignore + self.assertEqual(FooOnDemandAndJustInTime.counter, 2) + self.assertIn("_init_params", foo_2.__dict__) + self.assertEqual(foo_2.__dict__["_init_params"], {"arg": 789}) + + self.assertEqual(FooOnDemandAndJustInTime.counter, 2) + # __init__ is NOT called + self.assertEqual(foo_2.my_arg, 7890) + self.assertEqual(FooOnDemandAndJustInTime.counter, 2) + + @specifiable(on_demand_init=True, just_in_time_init=True) + class FooForPickle(): + counter = 0 + + def __init__(self, arg): + self.my_arg = arg * 10 + type(self).counter += 1 + + def test_on_pickle(self): + FooForPickle = TestSpecifiable.FooForPickle + + import dill + FooForPickle.counter = 0 + foo = FooForPickle(456) + self.assertEqual(FooForPickle.counter, 0) + new_foo = dill.loads(dill.dumps(foo)) + self.assertEqual(FooForPickle.counter, 0) + self.assertEqual(new_foo.__dict__, foo.__dict__) + self.assertEqual(foo.my_arg, 4560) + self.assertEqual(FooForPickle.counter, 1) + new_foo_2 = dill.loads(dill.dumps(foo)) + self.assertEqual(FooForPickle.counter, 1) + self.assertEqual(new_foo_2.__dict__, foo.__dict__) + + # Note that pickle does not support classes/functions nested in a function. + import pickle + FooForPickle.counter = 0 + foo = FooForPickle(456) + self.assertEqual(FooForPickle.counter, 0) + new_foo = pickle.loads(pickle.dumps(foo)) + self.assertEqual(FooForPickle.counter, 0) + self.assertEqual(new_foo.__dict__, foo.__dict__) + self.assertEqual(foo.my_arg, 4560) + self.assertEqual(FooForPickle.counter, 1) + new_foo_2 = pickle.loads(pickle.dumps(foo)) + self.assertEqual(FooForPickle.counter, 1) + self.assertEqual(new_foo_2.__dict__, foo.__dict__) + + import cloudpickle + FooForPickle.counter = 0 + foo = FooForPickle(456) + self.assertEqual(FooForPickle.counter, 0) + new_foo = cloudpickle.loads(cloudpickle.dumps(foo)) + self.assertEqual(FooForPickle.counter, 0) + self.assertEqual(new_foo.__dict__, foo.__dict__) + self.assertEqual(foo.my_arg, 4560) + self.assertEqual(FooForPickle.counter, 1) + new_foo_2 = cloudpickle.loads(cloudpickle.dumps(foo)) + self.assertEqual(FooForPickle.counter, 1) + self.assertEqual(new_foo_2.__dict__, foo.__dict__) + + +@specifiable +class Parent(): + counter = 0 + parent_class_var = 1000 + + def __init__(self, p): + self.parent_inst_var = p * 10 + Parent.counter += 1 + + +@specifiable +class Child_1(Parent): + counter = 0 + child_class_var = 2001 + + def __init__(self, c): + super().__init__(c) + self.child_inst_var = c + 1 + Child_1.counter += 1 + + +@specifiable +class Child_2(Parent): + counter = 0 + child_class_var = 2001 + + def __init__(self, c): + self.child_inst_var = c + 1 + super().__init__(c) + Child_2.counter += 1 + + +@specifiable +class Child_Error_1(Parent): + counter = 0 + child_class_var = 2001 + + def __init__(self, c): + self.child_inst_var += 1 # type: ignore + super().__init__(c) + Child_2.counter += 1 + + +@specifiable +class Child_Error_2(Parent): + counter = 0 + child_class_var = 2001 + + def __init__(self, c): + self.parent_inst_var += 1 # type: ignore + Child_2.counter += 1 + + +class TestNestedSpecifiable(unittest.TestCase): + @parameterized.expand([[Child_1, 0], [Child_2, 0], [Child_1, 1], [Child_2, 1], + [Child_1, 2], [Child_2, 2]]) + def test_nested_specifiable(self, Child, mode): + Parent.counter = 0 + Child.counter = 0 + child = Child(5) + + self.assertEqual(Parent.counter, 0) + self.assertEqual(Child.counter, 0) + + # accessing class vars won't trigger __init__ + self.assertEqual(child.parent_class_var, 1000) + self.assertEqual(child.child_class_var, 2001) + self.assertEqual(Parent.counter, 0) + self.assertEqual(Child.counter, 0) + + # accessing instance var will trigger __init__ + if mode == 0: + self.assertEqual(child.parent_inst_var, 50) + elif mode == 1: + self.assertEqual(child.child_inst_var, 6) + else: + self.assertRaises(AttributeError, lambda: child.unknown_var) + + self.assertEqual(Parent.counter, 1) + self.assertEqual(Child.counter, 1) + + # after initialization, it won't trigger __init__ again + self.assertEqual(child.parent_inst_var, 50) + self.assertEqual(child.child_inst_var, 6) + self.assertRaises(AttributeError, lambda: child.unknown_var) + + self.assertEqual(Parent.counter, 1) + self.assertEqual(Child.counter, 1) + + def test_error_in_child(self): + Parent.counter = 0 + child_1 = Child_Error_1(5) + + self.assertEqual(child_1.child_class_var, 2001) + + # error during child initialization + self.assertRaises(AttributeError, lambda: child_1.child_inst_var) # type: ignore + self.assertEqual(Parent.counter, 0) + self.assertEqual(Child_1.counter, 0) + + child_2 = Child_Error_2(5) + self.assertEqual(child_2.child_class_var, 2001) + + # error during child initialization + self.assertRaises(AttributeError, lambda: child_2.parent_inst_var) # type: ignore + self.assertEqual(Parent.counter, 0) + self.assertEqual(Child_2.counter, 0) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() From 3ea1d0ad3a11d6e96ff880dd3703d861294a296a Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 4 Feb 2025 00:22:22 -0500 Subject: [PATCH 02/11] Add subspaces to global specifiable map --- .../apache_beam/ml/anomaly/specifiable.py | 29 ++++++++++++++++--- .../ml/anomaly/specifiable_test.py | 16 +++++----- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index 478ca93f941a..5db63721a33e 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -30,11 +30,27 @@ from typing_extensions import Self -KNOWN_SPECIFIABLE = {} +ACCEPTED_SPECIFIABLE_SUBSPACES = [ + "EnsembleAnomalyDetector", + "AnomalyDetector", + "ThresholdFn", + "AggregationFn", + "*" +] +KNOWN_SPECIFIABLE = {"*": {}} SpecT = TypeVar('SpecT', bound='Specifiable') +def get_subspace(cls): + subspace = "*" + for c in cls.mro(): + if c in ACCEPTED_SPECIFIABLE_SUBSPACES: + subspace = c.__name__ # type: ignore + break + return subspace + + @dataclasses.dataclass(frozen=True) class Spec(): type: str @@ -61,7 +77,8 @@ def from_spec(cls, spec: Spec) -> Self: if spec.type is None: raise ValueError(f"Spec type not found in {spec}") - subclass: Type[Self] = KNOWN_SPECIFIABLE.get(spec.type, None) + subspace = get_subspace(cls) + subclass: Type[Self] = KNOWN_SPECIFIABLE[subspace].get(spec.type, None) if subclass is None: raise ValueError(f"Unknown spec type '{spec.type}' in {spec}") @@ -94,10 +111,14 @@ def register(cls, key, error_if_exists) -> None: if key is None: key = cls.__name__ - if key in KNOWN_SPECIFIABLE and error_if_exists: + subspace = get_subspace(cls) + if subspace in KNOWN_SPECIFIABLE and key in KNOWN_SPECIFIABLE[ + subspace] and error_if_exists: raise ValueError(f"{key} is already registered for specifiable") - KNOWN_SPECIFIABLE[key] = cls + if subspace not in KNOWN_SPECIFIABLE: + KNOWN_SPECIFIABLE[subspace] = {} + KNOWN_SPECIFIABLE[subspace][key] = cls cls._key = key diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py index d176d1d34239..28e2632933fa 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py @@ -36,19 +36,19 @@ class MyClass(): # class is not decorated/registered self.assertRaises(AttributeError, lambda: MyClass().to_spec()) # type: ignore - self.assertNotIn("MyKey", KNOWN_SPECIFIABLE) + self.assertNotIn("MyKey", KNOWN_SPECIFIABLE["*"]) MyClass = specifiable(key="MyKey")(MyClass) - self.assertIn("MyKey", KNOWN_SPECIFIABLE) - self.assertEqual(KNOWN_SPECIFIABLE["MyKey"], MyClass) + self.assertIn("MyKey", KNOWN_SPECIFIABLE["*"]) + self.assertEqual(KNOWN_SPECIFIABLE["*"]["MyKey"], MyClass) # By default, an error is raised if the key is duplicated self.assertRaises(ValueError, specifiable(key="MyKey"), MyClass) # But it is ok if a different key is used for the same class _ = specifiable(key="MyOtherKey")(MyClass) - self.assertIn("MyOtherKey", KNOWN_SPECIFIABLE) + self.assertIn("MyOtherKey", KNOWN_SPECIFIABLE["*"]) # Or, use a parameter to suppress the error specifiable(key="MyKey", error_if_exists=False)(MyClass) @@ -59,8 +59,8 @@ def test_decorator_key(self): class MySecondClass(): pass - self.assertIn("MySecondClass", KNOWN_SPECIFIABLE) - self.assertEqual(KNOWN_SPECIFIABLE["MySecondClass"], MySecondClass) + self.assertIn("MySecondClass", KNOWN_SPECIFIABLE["*"]) + self.assertEqual(KNOWN_SPECIFIABLE["*"]["MySecondClass"], MySecondClass) self.assertTrue(isinstance(MySecondClass(), Specifiable)) # use decorator with key parameter @@ -68,8 +68,8 @@ class MySecondClass(): class MyThirdClass(): pass - self.assertIn("MyThirdKey", KNOWN_SPECIFIABLE) - self.assertEqual(KNOWN_SPECIFIABLE["MyThirdKey"], MyThirdClass) + self.assertIn("MyThirdKey", KNOWN_SPECIFIABLE["*"]) + self.assertEqual(KNOWN_SPECIFIABLE["*"]["MyThirdKey"], MyThirdClass) def test_init_params_in_specifiable(self): @specifiable From 0e6b9a7cf6fc728ae333cb6217b2198dea235c99 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 4 Feb 2025 00:22:43 -0500 Subject: [PATCH 03/11] Add __init__.py --- sdks/python/apache_beam/ml/anomaly/__init__.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 sdks/python/apache_beam/ml/anomaly/__init__.py diff --git a/sdks/python/apache_beam/ml/anomaly/__init__.py b/sdks/python/apache_beam/ml/anomaly/__init__.py new file mode 100644 index 000000000000..cce3acad34a4 --- /dev/null +++ b/sdks/python/apache_beam/ml/anomaly/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# From 4023c80dada85d97aa05d74de3d872352a518b94 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 4 Feb 2025 00:35:15 -0500 Subject: [PATCH 04/11] Fix lints --- .../apache_beam/ml/anomaly/base_test.py | 19 ++++++------ .../apache_beam/ml/anomaly/specifiable.py | 6 ++-- .../ml/anomaly/specifiable_test.py | 30 ++++++++++--------- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/sdks/python/apache_beam/ml/anomaly/base_test.py b/sdks/python/apache_beam/ml/anomaly/base_test.py index 1894601624ae..f425525a250a 100644 --- a/sdks/python/apache_beam/ml/anomaly/base_test.py +++ b/sdks/python/apache_beam/ml/anomaly/base_test.py @@ -20,10 +20,10 @@ import logging import unittest +from apache_beam.ml.anomaly.base import AggregationFn from apache_beam.ml.anomaly.base import AnomalyDetector from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector from apache_beam.ml.anomaly.base import ThresholdFn -from apache_beam.ml.anomaly.base import AggregationFn from apache_beam.ml.anomaly.specifiable import Spec from apache_beam.ml.anomaly.specifiable import Specifiable from apache_beam.ml.anomaly.specifiable import specifiable @@ -56,8 +56,9 @@ def learn_one(self): def score_one(self): ... - def __eq__(self, value: "TestAnomalyDetector.Dummy") -> bool: - return self._my_arg == value._my_arg + def __eq__(self, value) -> bool: + return isinstance(value, TestAnomalyDetector.Dummy) and \ + self._my_arg == value._my_arg def test_unknown_detector(self): self.assertRaises(ValueError, Specifiable.from_spec, Spec(type="unknown")) @@ -141,9 +142,9 @@ def learn_one(self): def score_one(self): ... - def __eq__( - self, value: 'TestEnsembleAnomalyDetector.DummyEnsemble') -> bool: - return self._my_ensemble_arg == value._my_ensemble_arg + def __eq__(self, value) -> bool: + return isinstance(value, TestEnsembleAnomalyDetector.DummyEnsemble) and \ + self._my_ensemble_arg == value._my_ensemble_arg @specifiable(on_demand_init=False) class DummyWeakLearner(AnomalyDetector): @@ -157,9 +158,9 @@ def learn_one(self): def score_one(self): ... - def __eq__( - self, value: 'TestEnsembleAnomalyDetector.DummyWeakLearner') -> bool: - return self._my_arg == value._my_arg + def __eq__(self, value) -> bool: + return isinstance(value, TestEnsembleAnomalyDetector.DummyWeakLearner) \ + and self._my_arg == value._my_arg def test_model_id_on_known_detector(self): a = self.DummyEnsemble() diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index 5db63721a33e..474e509fee97 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -24,8 +24,8 @@ from typing import ClassVar from typing import List from typing import Protocol -from typing import TypeVar from typing import Type +from typing import TypeVar from typing import runtime_checkable from typing_extensions import Self @@ -46,7 +46,7 @@ def get_subspace(cls): subspace = "*" for c in cls.mro(): if c in ACCEPTED_SPECIFIABLE_SUBSPACES: - subspace = c.__name__ # type: ignore + subspace = c.__name__ break return subspace @@ -205,7 +205,7 @@ def new_getattr(self, name): cls._run_init = run_init cls.to_spec = Specifiable.to_spec cls._to_spec_helper = staticmethod(Specifiable._to_spec_helper) - cls.from_spec = classmethod(Specifiable.from_spec) # type: ignore + cls.from_spec = classmethod(Specifiable.from_spec) cls._from_spec_helper = staticmethod(Specifiable._from_spec_helper) return cls diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py index 28e2632933fa..71910b03ed5b 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py @@ -17,15 +17,16 @@ import dataclasses import logging -from parameterized import parameterized +import unittest from typing import List from typing import Optional -import unittest +from parameterized import parameterized + +from apache_beam.ml.anomaly.specifiable import KNOWN_SPECIFIABLE from apache_beam.ml.anomaly.specifiable import Spec from apache_beam.ml.anomaly.specifiable import Specifiable from apache_beam.ml.anomaly.specifiable import specifiable -from apache_beam.ml.anomaly.specifiable import KNOWN_SPECIFIABLE class TestSpecifiable(unittest.TestCase): @@ -34,7 +35,7 @@ class MyClass(): pass # class is not decorated/registered - self.assertRaises(AttributeError, lambda: MyClass().to_spec()) # type: ignore + self.assertRaises(AttributeError, lambda: MyClass().to_spec()) self.assertNotIn("MyKey", KNOWN_SPECIFIABLE["*"]) @@ -124,8 +125,9 @@ def __init__(self, product: Product, quantity: int = 1): self._product = product self._quantity = quantity - def __eq__(self, value: 'Entry') -> bool: - return self._product == value._product and \ + def __eq__(self, value) -> bool: + return isinstance(value, Entry) and \ + self._product == value._product and \ self._quantity == value._quantity @specifiable( @@ -199,10 +201,10 @@ def __init__(self, arg): self.assertRaises(AttributeError, getattr, foo, "my_arg") self.assertRaises(AttributeError, lambda: foo.my_arg) self.assertRaises(AttributeError, getattr, foo, "unknown_arg") - self.assertRaises(AttributeError, lambda: foo.unknown_arg) # type: ignore + self.assertRaises(AttributeError, lambda: foo.unknown_arg) self.assertEqual(FooOnDemand.counter, 0) - foo_2 = FooOnDemand(456, _run_init=True) # type: ignore + foo_2 = FooOnDemand(456, _run_init=True) self.assertEqual(FooOnDemand.counter, 1) self.assertIn("_init_params", foo_2.__dict__) self.assertEqual(foo_2.__dict__["_init_params"], {"arg": 456}) @@ -231,7 +233,7 @@ def __init__(self, arg): # __init__ is called when trying to accessing an attribute self.assertEqual(foo.my_arg, 3210) self.assertEqual(FooJustInTime.counter, 1) - self.assertRaises(AttributeError, lambda: foo.unknown_arg) # type: ignore + self.assertRaises(AttributeError, lambda: foo.unknown_arg) self.assertEqual(FooJustInTime.counter, 1) def test_on_demand_and_just_in_time_init(self): @@ -255,7 +257,7 @@ def __init__(self, arg): self.assertEqual(FooOnDemandAndJustInTime.counter, 1) # __init__ is called - foo_2 = FooOnDemandAndJustInTime(789, _run_init=True) # type: ignore + foo_2 = FooOnDemandAndJustInTime(789, _run_init=True) self.assertEqual(FooOnDemandAndJustInTime.counter, 2) self.assertIn("_init_params", foo_2.__dict__) self.assertEqual(foo_2.__dict__["_init_params"], {"arg": 789}) @@ -355,7 +357,7 @@ class Child_Error_1(Parent): child_class_var = 2001 def __init__(self, c): - self.child_inst_var += 1 # type: ignore + self.child_inst_var += 1 super().__init__(c) Child_2.counter += 1 @@ -366,7 +368,7 @@ class Child_Error_2(Parent): child_class_var = 2001 def __init__(self, c): - self.parent_inst_var += 1 # type: ignore + self.parent_inst_var += 1 Child_2.counter += 1 @@ -413,7 +415,7 @@ def test_error_in_child(self): self.assertEqual(child_1.child_class_var, 2001) # error during child initialization - self.assertRaises(AttributeError, lambda: child_1.child_inst_var) # type: ignore + self.assertRaises(AttributeError, lambda: child_1.child_inst_var) self.assertEqual(Parent.counter, 0) self.assertEqual(Child_1.counter, 0) @@ -421,7 +423,7 @@ def test_error_in_child(self): self.assertEqual(child_2.child_class_var, 2001) # error during child initialization - self.assertRaises(AttributeError, lambda: child_2.parent_inst_var) # type: ignore + self.assertRaises(AttributeError, lambda: child_2.parent_inst_var) self.assertEqual(Parent.counter, 0) self.assertEqual(Child_2.counter, 0) From 24c77e947521773117d9a3706829936b2ce8c1ab Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 4 Feb 2025 10:27:08 -0500 Subject: [PATCH 05/11] Fix get_subspace when calling from from_spec --- .../apache_beam/ml/anomaly/specifiable.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index 474e509fee97..40b99691bbfc 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -42,13 +42,19 @@ SpecT = TypeVar('SpecT', bound='Specifiable') -def get_subspace(cls): - subspace = "*" - for c in cls.mro(): - if c in ACCEPTED_SPECIFIABLE_SUBSPACES: - subspace = c.__name__ - break - return subspace +def get_subspace(cls, type=None): + if type is None: + subspace = "*" + for c in cls.mro(): + if c.__name__ in ACCEPTED_SPECIFIABLE_SUBSPACES: + subspace = c.__name__ + break + return subspace + else: + for subspace in ACCEPTED_SPECIFIABLE_SUBSPACES: + if subspace in KNOWN_SPECIFIABLE and type in KNOWN_SPECIFIABLE[subspace]: + return subspace + raise ValueError(f"subspace for {cls.__name__} not found.") @dataclasses.dataclass(frozen=True) @@ -77,7 +83,7 @@ def from_spec(cls, spec: Spec) -> Self: if spec.type is None: raise ValueError(f"Spec type not found in {spec}") - subspace = get_subspace(cls) + subspace = get_subspace(cls, spec.type) subclass: Type[Self] = KNOWN_SPECIFIABLE[subspace].get(spec.type, None) if subclass is None: raise ValueError(f"Unknown spec type '{spec.type}' in {spec}") From e4a32c2e97025e0328166530bbecfd31e771d644 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Thu, 6 Feb 2025 00:01:10 -0500 Subject: [PATCH 06/11] Refactor code, add tests and add docstrings. --- sdks/python/apache_beam/ml/anomaly/base.py | 103 ++++++- .../apache_beam/ml/anomaly/base_test.py | 268 +++++++++-------- .../apache_beam/ml/anomaly/specifiable.py | 269 ++++++++++++------ .../ml/anomaly/specifiable_test.py | 204 ++++++++----- 4 files changed, 561 insertions(+), 283 deletions(-) diff --git a/sdks/python/apache_beam/ml/anomaly/base.py b/sdks/python/apache_beam/ml/anomaly/base.py index fdbb88548d55..dfe29ee55ee9 100644 --- a/sdks/python/apache_beam/ml/anomaly/base.py +++ b/sdks/python/apache_beam/ml/anomaly/base.py @@ -15,7 +15,9 @@ # limitations under the License. # -"""Base classes for anomaly detection""" +""" +Base classes for anomaly detection +""" from __future__ import annotations import abc @@ -26,24 +28,51 @@ import apache_beam as beam +__all__ = [ + "AnomalyPrediction", + "AnomalyResult", + "ThresholdFn", + "AggregationFn", + "AnomalyDetector", + "EnsembleAnomalyDetector" +] + @dataclass(frozen=True) class AnomalyPrediction(): + """A dataclass for anomaly detection predictions.""" + #: The ID of detector (model) that generates the prediction. model_id: Optional[str] = None + #: The outlier score resulting from applying the detector to the input data. score: Optional[float] = None + #: The outlier label (normal or outlier) derived from the outlier score. label: Optional[int] = None + #: The threshold used to determine the label. threshold: Optional[float] = None + #: Additional information about the prediction. info: str = "" + #: If enabled, a list of `AnomalyPrediction` objects used to derive the + #: aggregated prediction. agg_history: Optional[Iterable[AnomalyPrediction]] = None @dataclass(frozen=True) class AnomalyResult(): + """A dataclass for the anomaly detection results""" + #: The original input data. example: beam.Row + #: The `AnomalyPrediction` object containing the prediction. prediction: AnomalyPrediction class ThresholdFn(abc.ABC): + """An abstract base class for threshold functions. + + Args: + normal_label: The integer label used to identify normal data. Defaults to 0. + outlier_label: The integer label used to identify outlier data. Defaults to + 1. + """ def __init__(self, normal_label: int = 0, outlier_label: int = 1): self._normal_label = normal_label self._outlier_label = outlier_label @@ -51,26 +80,59 @@ def __init__(self, normal_label: int = 0, outlier_label: int = 1): @property @abc.abstractmethod def is_stateful(self) -> bool: + """Indicates whether the threshold function is stateful or not.""" raise NotImplementedError @property @abc.abstractmethod def threshold(self) -> Optional[float]: + """Retrieves the current threshold value, or None if not set.""" raise NotImplementedError @abc.abstractmethod def apply(self, score: Optional[float]) -> int: + """Applies the threshold function to a given score to classify it as + normal or outlier. + + Args: + score: The outlier score generated from the detector (model). + + Returns: + The label assigned to the score, either `self._normal_label` + or `self._outlier_label` + """ raise NotImplementedError class AggregationFn(abc.ABC): + """An abstract base class for aggregation functions.""" @abc.abstractmethod def apply( self, predictions: Iterable[AnomalyPrediction]) -> AnomalyPrediction: + """Applies the aggregation function to an iterable of predictions, either on + their outlier scores or labels. + + Args: + predictions: An Iterable of `AnomalyPrediction` objects to aggregate. + + Returns: + An `AnomalyPrediction` object containing the aggregated result. + """ raise NotImplementedError class AnomalyDetector(abc.ABC): + """An abstract base class for anomaly detectors. + + Args: + model_id: The ID of detector (model). Defaults to the value of the + `spec_type` attribute, or 'unknown' if not set. + features: An Iterable of strings representing the names of the input + features in the `beam.Row` + target: The name of the target field in the `beam.Row`. + threshold_criterion: An optional `ThresholdFn` to apply to the outlier score + and yield a label. + """ def __init__( self, model_id: Optional[str] = None, @@ -79,28 +141,52 @@ def __init__( threshold_criterion: Optional[ThresholdFn] = None, **kwargs): self._model_id = model_id if model_id is not None else getattr( - self, '_key', 'unknown') + self, 'spec_type', 'unknown') self._features = features self._target = target self._threshold_criterion = threshold_criterion @abc.abstractmethod def learn_one(self, x: beam.Row) -> None: + """Trains the detector on a single data instance. + + Args: + x: A `beam.Row` representing the data instance. + """ raise NotImplementedError @abc.abstractmethod def score_one(self, x: beam.Row) -> float: + """Scores a single data instance for anomalies. + + Args: + x: A `beam.Row` representing the data instance. + + Returns: + The outlier score as a float. + """ raise NotImplementedError class EnsembleAnomalyDetector(AnomalyDetector): + """An abstract base class for an ensemble of anomaly (sub-)detectors. + + Args: + sub_detectors: A List of `AnomalyDetector` used in this ensemble model. + aggregation_strategy: An optional `AggregationFn` to apply to the + predictions from all sub-detectors and yield an aggregated result. + model_id: Inherited from `AnomalyDetector`. + features: Inherited from `AnomalyDetector`. + target: Inherited from `AnomalyDetector`. + threshold_criterion: Inherited from `AnomalyDetector`. + """ def __init__( self, sub_detectors: Optional[List[AnomalyDetector]] = None, aggregation_strategy: Optional[AggregationFn] = None, **kwargs): if "model_id" not in kwargs or kwargs["model_id"] is None: - kwargs["model_id"] = getattr(self, '_key', 'custom') + kwargs["model_id"] = getattr(self, 'spec_type', 'custom') super().__init__(**kwargs) @@ -108,7 +194,18 @@ def __init__( self._sub_detectors = sub_detectors def learn_one(self, x: beam.Row) -> None: + """Inherited from `AnomalyDetector.learn_one`. + + This method is never called during ensemble detector training. The training + process is done on each sub-detector independently and in parallel. + """ raise NotImplementedError def score_one(self, x: beam.Row) -> float: + """Inherited from `AnomalyDetector.score_one`. + + This method is never called during ensemble detector scoring. The scoring + process is done on sub-detector independently and in parallel, and then + the results are aggregated in the pipeline. + """ raise NotImplementedError diff --git a/sdks/python/apache_beam/ml/anomaly/base_test.py b/sdks/python/apache_beam/ml/anomaly/base_test.py index f425525a250a..715d5128ee1e 100644 --- a/sdks/python/apache_beam/ml/anomaly/base_test.py +++ b/sdks/python/apache_beam/ml/anomaly/base_test.py @@ -20,6 +20,8 @@ import logging import unittest +from parameterized import parameterized + from apache_beam.ml.anomaly.base import AggregationFn from apache_beam.ml.anomaly.base import AnomalyDetector from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector @@ -30,69 +32,85 @@ class TestAnomalyDetector(unittest.TestCase): - @specifiable(on_demand_init=False) - class DummyThreshold(ThresholdFn): - def __init__(self, my_threshold_arg=None): - ... - - def is_stateful(self): - return False - - def threshold(self): - ... - - def apply(self, x): - ... - - @specifiable(on_demand_init=False) - class Dummy(AnomalyDetector): - def __init__(self, my_arg=None, **kwargs): - self._my_arg = my_arg - super().__init__(**kwargs) - - def learn_one(self): - ... - - def score_one(self): - ... - - def __eq__(self, value) -> bool: - return isinstance(value, TestAnomalyDetector.Dummy) and \ - self._my_arg == value._my_arg - - def test_unknown_detector(self): - self.assertRaises(ValueError, Specifiable.from_spec, Spec(type="unknown")) - - def test_model_id_on_known_detector(self): - a = self.Dummy( + @parameterized.expand([(False, False), (True, False), (False, True), + (True, True)]) + def test_model_id_and_spec(self, on_demand_init, just_in_time_init): + @specifiable( + on_demand_init=on_demand_init, + just_in_time_init=just_in_time_init, + error_if_exists=False) + class DummyThreshold(ThresholdFn): + def __init__(self, my_threshold_arg=None): + ... + + def is_stateful(self): + return False + + def threshold(self): + ... + + def apply(self, x): + ... + + @specifiable( + on_demand_init=on_demand_init, + just_in_time_init=just_in_time_init, + error_if_exists=False) + class Dummy(AnomalyDetector): + def __init__(self, my_arg=None, **kwargs): + self._my_arg = my_arg + super().__init__(**kwargs) + + def learn_one(self): + ... + + def score_one(self): + ... + + def __eq__(self, value) -> bool: + return isinstance(value, Dummy) and \ + self._my_arg == value._my_arg + + a = Dummy( my_arg="abc", target="ABC", - threshold_criterion=(t1 := self.DummyThreshold(2))) - - self.assertEqual(a._model_id, "Dummy") - self.assertEqual(a._target, "ABC") - self.assertEqual(a._my_arg, "abc") + threshold_criterion=(t1 := DummyThreshold(2))) + + # The class attributes can only be accessed when + # (1) on_demand_init == False, just_in_time_init == False + # In this case, the true __init__ is called immediately during object + # initialization + # (2) just_in_time_init == True + # In this case, regardless what on_demand_init is, the true __init__ + # is called when we first access any class attribute (delay init). + if just_in_time_init or not on_demand_init: + self.assertEqual(a._model_id, "Dummy") + self.assertEqual(a._target, "ABC") + self.assertEqual(a._my_arg, "abc") assert isinstance(a, Specifiable) self.assertEqual( - a._init_params, { + a.init_kwargs, { "my_arg": "abc", "target": "ABC", "threshold_criterion": t1, }) - b = self.Dummy( + b = Dummy( my_arg="efg", model_id="my_dummy", target="EFG", - threshold_criterion=(t2 := self.DummyThreshold(2))) - self.assertEqual(b._model_id, "my_dummy") - self.assertEqual(b._target, "EFG") - self.assertEqual(b._my_arg, "efg") + threshold_criterion=(t2 := DummyThreshold(3))) + + # See the comment above for more details. + if just_in_time_init or not on_demand_init: + self.assertEqual(b._model_id, "my_dummy") + self.assertEqual(b._target, "EFG") + self.assertEqual(b._my_arg, "efg") assert isinstance(b, Specifiable) self.assertEqual( - b._init_params, + b.init_kwargs, { "model_id": "my_dummy", "my_arg": "efg", @@ -100,88 +118,97 @@ def test_model_id_on_known_detector(self): "threshold_criterion": t2, }) - def test_from_and_to_specifiable(self): - obj = self.Dummy( - my_arg="hij", - model_id="my_dummy", - target="HIJ", - threshold_criterion=self.DummyThreshold(4)) - - assert isinstance(obj, Specifiable) - spec = obj.to_spec() + spec = b.to_spec() expected_spec = Spec( type="Dummy", config={ - "my_arg": "hij", + "my_arg": "efg", "model_id": "my_dummy", - "target": "HIJ", + "target": "EFG", "threshold_criterion": Spec( - type="DummyThreshold", config={"my_threshold_arg": 4}), + type="DummyThreshold", config={"my_threshold_arg": 3}), }) self.assertEqual(spec, expected_spec) - new_obj = Specifiable.from_spec(spec) - self.assertEqual(obj, new_obj) + b_dup = Specifiable.from_spec(spec) + # We need to manually call the original __init__ function in one scenario. + # See the comment above for more details. + if on_demand_init and not just_in_time_init: + b.run_original_init() + self.assertEqual(b, b_dup) -class TestEnsembleAnomalyDetector(unittest.TestCase): - @specifiable(on_demand_init=False) - class DummyAggregation(AggregationFn): - def apply(self, x): - ... - - @specifiable(on_demand_init=False) - class DummyEnsemble(EnsembleAnomalyDetector): - def __init__(self, my_ensemble_arg=None, **kwargs): - super().__init__(**kwargs) - self._my_ensemble_arg = my_ensemble_arg - - def learn_one(self): - ... - - def score_one(self): - ... - - def __eq__(self, value) -> bool: - return isinstance(value, TestEnsembleAnomalyDetector.DummyEnsemble) and \ - self._my_ensemble_arg == value._my_ensemble_arg - - @specifiable(on_demand_init=False) - class DummyWeakLearner(AnomalyDetector): - def __init__(self, my_arg=None, **kwargs): - super().__init__(**kwargs) - self._my_arg = my_arg - - def learn_one(self): - ... - - def score_one(self): - ... - def __eq__(self, value) -> bool: - return isinstance(value, TestEnsembleAnomalyDetector.DummyWeakLearner) \ - and self._my_arg == value._my_arg - - def test_model_id_on_known_detector(self): - a = self.DummyEnsemble() - self.assertEqual(a._model_id, "DummyEnsemble") - - b = self.DummyEnsemble(model_id="my_dummy_ensemble") - self.assertEqual(b._model_id, "my_dummy_ensemble") - - c = EnsembleAnomalyDetector() - self.assertEqual(c._model_id, "custom") - - d = EnsembleAnomalyDetector(model_id="my_dummy_ensemble_2") - self.assertEqual(d._model_id, "my_dummy_ensemble_2") - - def test_from_and_to_specifiable(self): - d1 = self.DummyWeakLearner(my_arg=1) - d2 = self.DummyWeakLearner(my_arg=2) - ensemble = self.DummyEnsemble( +class TestEnsembleAnomalyDetector(unittest.TestCase): + @parameterized.expand([(False, False), (True, False), (False, True), + (True, True)]) + def test_model_id_and_spec(self, on_demand_init, just_in_time_init): + @specifiable( + on_demand_init=on_demand_init, + just_in_time_init=just_in_time_init, + error_if_exists=False) + class DummyAggregation(AggregationFn): + def apply(self, x): + ... + + @specifiable( + on_demand_init=on_demand_init, + just_in_time_init=just_in_time_init, + error_if_exists=False) + class DummyEnsemble(EnsembleAnomalyDetector): + def __init__(self, my_ensemble_arg=None, **kwargs): + super().__init__(**kwargs) + self._my_ensemble_arg = my_ensemble_arg + + def learn_one(self): + ... + + def score_one(self): + ... + + def __eq__(self, value) -> bool: + return isinstance(value, DummyEnsemble) and \ + self._my_ensemble_arg == value._my_ensemble_arg + + @specifiable( + on_demand_init=on_demand_init, + just_in_time_init=just_in_time_init, + error_if_exists=False) + class DummyWeakLearner(AnomalyDetector): + def __init__(self, my_arg=None, **kwargs): + super().__init__(**kwargs) + self._my_arg = my_arg + + def learn_one(self): + ... + + def score_one(self): + ... + + def __eq__(self, value) -> bool: + return isinstance(value, DummyWeakLearner) \ + and self._my_arg == value._my_arg + + # See the comment in TestAnomalyDetector for more details. + if just_in_time_init or not on_demand_init: + a = DummyEnsemble() + self.assertEqual(a._model_id, "DummyEnsemble") + + b = DummyEnsemble(model_id="my_dummy_ensemble") + self.assertEqual(b._model_id, "my_dummy_ensemble") + + c = EnsembleAnomalyDetector() + self.assertEqual(c._model_id, "custom") + + d = EnsembleAnomalyDetector(model_id="my_dummy_ensemble_2") + self.assertEqual(d._model_id, "my_dummy_ensemble_2") + + d1 = DummyWeakLearner(my_arg=1) + d2 = DummyWeakLearner(my_arg=2) + ensemble = DummyEnsemble( my_ensemble_arg=123, learners=[d1, d2], - aggregation_strategy=self.DummyAggregation()) + aggregation_strategy=DummyAggregation()) expected_spec = Spec( type="DummyEnsemble", @@ -201,8 +228,13 @@ def test_from_and_to_specifiable(self): spec = ensemble.to_spec() self.assertEqual(spec, expected_spec) - new_ensemble = Specifiable.from_spec(spec) - self.assertEqual(ensemble, new_ensemble) + ensemble_dup = Specifiable.from_spec(spec) + + # See the comment in TestAnomalyDetector for more details. + if on_demand_init and not just_in_time_init: + ensemble.run_original_init() + + self.assertEqual(ensemble, ensemble_dup) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index 40b99691bbfc..7765f8998ada 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -15,6 +15,10 @@ # limitations under the License. # +""" +A module that provides utilities to turn a class into a Specifiable subclass. +""" + from __future__ import annotations import dataclasses @@ -30,6 +34,8 @@ from typing_extensions import Self +__all__ = ["KNOWN_SPECIFIABLE", "Spec", "Specifiable", "specifiable"] + ACCEPTED_SPECIFIABLE_SUBSPACES = [ "EnsembleAnomalyDetector", "AnomalyDetector", @@ -37,60 +43,110 @@ "AggregationFn", "*" ] + +#: A nested dictionary for efficient lookup of Specifiable subclasses. +#: Structure: KNOWN_SPECIFIABLE[subspace][spec_type], where "subspace" is one of +#: the accepted subspaces that the class belongs to and "spec_type" is the class +#: name by default. Users can also specify a different value for "spec_type" +#: when applying the `specifiable` decorator to an existing class. KNOWN_SPECIFIABLE = {"*": {}} SpecT = TypeVar('SpecT', bound='Specifiable') -def get_subspace(cls, type=None): - if type is None: - subspace = "*" - for c in cls.mro(): - if c.__name__ in ACCEPTED_SPECIFIABLE_SUBSPACES: - subspace = c.__name__ - break - return subspace - else: - for subspace in ACCEPTED_SPECIFIABLE_SUBSPACES: - if subspace in KNOWN_SPECIFIABLE and type in KNOWN_SPECIFIABLE[subspace]: - return subspace +def _class_to_subspace(cls: Type, default="*") -> str: + """ + Search the class hierarchy to find the subspace: the closest ancestor class in + the class's method resolution order (MRO) whose name is found in the accepted + subspace list. This is usually called when registering a new Specifiable + class. + """ + for c in cls.mro(): + # + if c.__name__ in ACCEPTED_SPECIFIABLE_SUBSPACES: + return c.__name__ + + if default is None: raise ValueError(f"subspace for {cls.__name__} not found.") + return default + + +def _spec_type_to_subspace(type: str, default="*") -> str: + """ + Look for the subspace for a spec type. This is usually called to retrieve + the subspace of a registered Specifiable class. + """ + for subspace in ACCEPTED_SPECIFIABLE_SUBSPACES: + if type in KNOWN_SPECIFIABLE.get(subspace, {}): + return subspace + + if default is None: + raise ValueError(f"subspace for {type} not found.") + + return default + @dataclasses.dataclass(frozen=True) class Spec(): + """ + Dataclass for storing specifications of specifiable objects. + Objects can be initialized using the data in their corresponding spec. + The `type` field indicates the concrete Specifiable class, while + """ + #: A string indicating the concrete Specifiable class type: str + #: A dictionary of keyword arguments for the `__init__` method of the class. config: dict[str, Any] = dataclasses.field(default_factory=dict) @runtime_checkable class Specifiable(Protocol): - _key: ClassVar[str] - _init_params: dict[str, Any] + """Protocol that a Specifiable subclass needs to implement. + + Attributes: + spec_type: The value of the `type` field in the object's Spec for this + class. + init_kwargs: The raw keyword arguments passed to `__init__` during object + initialization. + """ + spec_type: ClassVar[str] + init_kwargs: dict[str, Any] + # a boolean to tell whether the original __init__ is called + _initialized: bool + # a boolean used by new_getattr to tell whether it is in an __init__ call + _in_init: bool @staticmethod - def _from_spec_helper(v): + def _from_spec_helper(v, _run_init): if isinstance(v, Spec): - return Specifiable.from_spec(v) + return Specifiable.from_spec(v, _run_init) if isinstance(v, List): - return [Specifiable._from_spec_helper(e) for e in v] + return [Specifiable._from_spec_helper(e, _run_init) for e in v] return v @classmethod - def from_spec(cls, spec: Spec) -> Self: + def from_spec(cls, spec: Spec, _run_init: bool = True) -> Self: + """Generate a Specifiable subclass object based on a spec.""" if spec.type is None: raise ValueError(f"Spec type not found in {spec}") - subspace = get_subspace(cls, spec.type) + subspace = _spec_type_to_subspace(spec.type) subclass: Type[Self] = KNOWN_SPECIFIABLE[subspace].get(spec.type, None) if subclass is None: raise ValueError(f"Unknown spec type '{spec.type}' in {spec}") - args = {k: Specifiable._from_spec_helper(v) for k, v in spec.config.items()} + kwargs = { + k: Specifiable._from_spec_helper(v, _run_init) + for k, + v in spec.config.items() + } - return subclass(**args) + if _run_init: + kwargs["_run_init"] = True + return subclass(**kwargs) @staticmethod def _to_spec_helper(v): @@ -103,121 +159,164 @@ def _to_spec_helper(v): return v def to_spec(self) -> Spec: - if getattr(type(self), '_key', None) is None: + """ + Generate a spec from a Specifiable subclass object. + """ + if getattr(type(self), 'spec_type', None) is None: raise ValueError( f"'{type(self).__name__}' not registered as Specifiable. " f"Decorate ({type(self).__name__}) with @specifiable") - args = {k: self._to_spec_helper(v) for k, v in self._init_params.items()} + args = {k: self._to_spec_helper(v) for k, v in self.init_kwargs.items()} - return Spec(type=self.__class__._key, config=args) + return Spec(type=self.__class__.spec_type, config=args) -def register(cls, key, error_if_exists) -> None: - if key is None: - key = cls.__name__ +# Register a Specifiable subclass in KNOWN_SPECIFIABLE +def _register(cls, spec_type=None, error_if_exists=True) -> None: + if spec_type is None: + # By default, spec type is the class name. Users can override this with + # other unique identifier. + spec_type = cls.__name__ - subspace = get_subspace(cls) - if subspace in KNOWN_SPECIFIABLE and key in KNOWN_SPECIFIABLE[ - subspace] and error_if_exists: - raise ValueError(f"{key} is already registered for specifiable") - - if subspace not in KNOWN_SPECIFIABLE: + subspace = _class_to_subspace(cls) + if subspace in KNOWN_SPECIFIABLE: + if spec_type in KNOWN_SPECIFIABLE[subspace] and error_if_exists: + raise ValueError(f"{spec_type} is already registered for specifiable") + else: KNOWN_SPECIFIABLE[subspace] = {} - KNOWN_SPECIFIABLE[subspace][key] = cls + KNOWN_SPECIFIABLE[subspace][spec_type] = cls - cls._key = key + cls.spec_type = spec_type -def track_init_params(inst, init_method, *args, **kwargs): +# Keep a copy of arguments that are used to call __init__ method, when the +# object is initialized. +def _get_init_kwargs(inst, init_method, *args, **kwargs): params = dict( zip(inspect.signature(init_method).parameters.keys(), (None, ) + args)) del params['self'] params.update(**kwargs) - inst._init_params = params + return params def specifiable( my_cls=None, /, *, - key=None, + spec_type=None, error_if_exists=True, on_demand_init=True, just_in_time_init=True): - - # register a specifiable, track init params for each instance, lazy init + """A decorator that turns a class into a Specifiable subclass by implementing + the Specifiable protocol. + + To use the decorator, simply place `@specifiable` before the class definition. + For finer control, the decorator accepts arguments + (e.g., `@specifiable(arg1=..., arg2=...)`). + + Args: + spec_type: The value of the `type` field in the Spec of a Specifiable + subclass. If not provided, the class name is used. + error_if_exists: If True, raise an exception if `spec_type` is already + registered. + on_demand_init: If True, allow on-demand object initialization. The original + `__init__` method will be called when `_run_init=True` is passed to the + object's initialization function. + just_in_time_init: If True, allow just-in-time object initialization. The + original `__init__` method will be called when an attribute is first + accessed. + """ def _wrapper(cls): - register(cls, key, error_if_exists) - - original_init = cls.__init__ - class_name = cls.__name__ - - def new_init(self, *args, **kwargs): + def new_init(self: Specifiable, *args, **kwargs): self._initialized = False - #self._nested_getattr = False - - if kwargs.get("_run_init", False): - run_init = True - del kwargs['_run_init'] - else: - run_init = False - - if '_init_params' not in self.__dict__: - track_init_params(self, original_init, *args, **kwargs) - - # If it is not a nested specifiable, we choose whether to skip original - # init call based on options. Otherwise, we always call original init - # for inner (parent/grandparent/etc) specifiable. - if (on_demand_init and not run_init) or \ + self._in_init = False + + run_init_request = False + if "_run_init" in kwargs: + run_init_request = kwargs["_run_init"] + del kwargs["_run_init"] + + if 'init_kwargs' not in self.__dict__: + # If it is a child specifiable (i.e.g init_kwargs not set), we determine + # whether to skip the original __init__ call based on options: + # on_demand_init, just_in_time_init and _run_init. + # Otherwise (i.e. init_kwargs is set), we always call the original + # __init__ method for ancestor specifiable. + self.init_kwargs = _get_init_kwargs( + self, original_init, *args, **kwargs) + logging.debug("Record init params in %s.new_init", class_name) + + if (on_demand_init and not run_init_request) or \ (not on_demand_init and just_in_time_init): + logging.debug("Skip original %s.__init__", class_name) return - logging.debug("call original %s.__init__ in new_init", class_name) + logging.debug("Call original %s.__init__ in new_init", class_name) + original_init(self, *args, **kwargs) self._initialized = True - def run_init(self): - original_init(self, **self._init_params) + def run_original_init(self): + self._in_init = True + original_init(self, **self.init_kwargs) + self._in_init = False + self._initialized = True + # __getattr__ is only called when an attribute is not found in the object def new_getattr(self, name): - if name == '_nested_getattr' or \ - ('_nested_getattr' in self.__dict__ and self._nested_getattr): - #self._nested_getattr = False - delattr(self, "_nested_getattr") + logging.debug( + "Trying to access %s.%s, but it is not found.", class_name, name) + + # Fix the infinite loop issue when pickling a Specifiable + if name in ["_in_init", "__getstate__"] and name not in self.__dict__: raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'") - # set it before original init, in case getattr is called in original init - self._nested_getattr = True + # If the attribute is not found during or after initialization, then + # it is a missing attribute. + if self._in_init or self._initialized: + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'") - if not self._initialized and name != "__getstate__": - logging.debug("call original %s.__init__ in new_getattr", class_name) - original_init(self, **self._init_params) - self._initialized = True + # Here, we know the object is not initialized, then we will call original + # init method. + logging.debug("Call original %s.__init__ in new_getattr", class_name) + run_original_init(self) - try: - logging.debug("call original %s.getattr in new_getattr", class_name) - ret = getattr(self, name) - finally: - # self._nested_getattr = False - delattr(self, "_nested_getattr") - return ret + # __getattribute__ is call for every attribute regardless whether it is + # present in the object. In this case, we don't cause an infinite loop + # if the attribute does not exist. + logging.debug( + "Call original %s.__getattribute__(%s) in new_getattr", + class_name, + name) + return self.__getattribute__(name) + # start of the function body of _wrapper + _register(cls, spec_type, error_if_exists) + + class_name = cls.__name__ + original_init = cls.__init__ + cls.__init__ = new_init if just_in_time_init: cls.__getattr__ = new_getattr - cls.__init__ = new_init - cls._run_init = run_init + cls.run_original_init = run_original_init cls.to_spec = Specifiable.to_spec cls._to_spec_helper = staticmethod(Specifiable._to_spec_helper) cls.from_spec = classmethod(Specifiable.from_spec) cls._from_spec_helper = staticmethod(Specifiable._from_spec_helper) return cls + # end of the function body of _wrapper + # When this decorator is called with arguments, i.e.. + # "@specifiable(arg1=...,arg2=...)", it is equivalent to assigning + # specifiable(arg1=..., arg2=...) to a variable, say decor_func, and then + # calling "@decor_func". if my_cls is None: - # support @specifiable(...) return _wrapper - # support @specifiable without arguments + # When this decorator is called without an argument, i.e. "@specifiable", + # we return the augmented class. return _wrapper(my_cls) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py index 71910b03ed5b..108dca49df7d 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py @@ -30,69 +30,83 @@ class TestSpecifiable(unittest.TestCase): - def test_register_specifiable(self): - class MyClass(): + def test_decorator_in_function_form(self): + class A(): pass - # class is not decorated/registered - self.assertRaises(AttributeError, lambda: MyClass().to_spec()) - - self.assertNotIn("MyKey", KNOWN_SPECIFIABLE["*"]) - - MyClass = specifiable(key="MyKey")(MyClass) - - self.assertIn("MyKey", KNOWN_SPECIFIABLE["*"]) - self.assertEqual(KNOWN_SPECIFIABLE["*"]["MyKey"], MyClass) - - # By default, an error is raised if the key is duplicated - self.assertRaises(ValueError, specifiable(key="MyKey"), MyClass) - - # But it is ok if a different key is used for the same class - _ = specifiable(key="MyOtherKey")(MyClass) - self.assertIn("MyOtherKey", KNOWN_SPECIFIABLE["*"]) - - # Or, use a parameter to suppress the error - specifiable(key="MyKey", error_if_exists=False)(MyClass) - - def test_decorator_key(self): - # use decorator without parameter + # class is not decorated and thus not registered + self.assertNotIn("A", KNOWN_SPECIFIABLE["*"]) + + # apply the decorator function to an existing class + A = specifiable(A) + self.assertEqual(A.spec_type, "A") + self.assertTrue(isinstance(A(), Specifiable)) + self.assertIn("A", KNOWN_SPECIFIABLE["*"]) + self.assertEqual(KNOWN_SPECIFIABLE["*"]["A"], A) + + # an error is raised if the specified spec_type already exists. + self.assertRaises(ValueError, specifiable, A) + + # apply the decorator function to an existing class with a different + # spec_type + A = specifiable(spec_type="A_DUP")(A) + self.assertEqual(A.spec_type, "A_DUP") + self.assertTrue(isinstance(A(), Specifiable)) + self.assertIn("A_DUP", KNOWN_SPECIFIABLE["*"]) + self.assertEqual(KNOWN_SPECIFIABLE["*"]["A_DUP"], A) + + # an error is raised if the specified spec_type already exists. + self.assertRaises(ValueError, specifiable(spec_type="A_DUP"), A) + + # but the error can be suppressed by setting error_if_exists=False. + try: + specifiable(spec_type="A_DUP", error_if_exists=False)(A) + except ValueError: + self.fail("The ValueError should be suppressed but instead it is raised.") + + def test_decorator_in_syntactic_sugar_form(self): + # call decorator without parameters @specifiable - class MySecondClass(): + class B(): pass - self.assertIn("MySecondClass", KNOWN_SPECIFIABLE["*"]) - self.assertEqual(KNOWN_SPECIFIABLE["*"]["MySecondClass"], MySecondClass) - self.assertTrue(isinstance(MySecondClass(), Specifiable)) + self.assertTrue(isinstance(B(), Specifiable)) + self.assertIn("B", KNOWN_SPECIFIABLE["*"]) + self.assertEqual(KNOWN_SPECIFIABLE["*"]["B"], B) - # use decorator with key parameter - @specifiable(key="MyThirdKey") - class MyThirdClass(): + # call decorator with parameters + @specifiable(spec_type="C_TYPE") + class C(): pass - self.assertIn("MyThirdKey", KNOWN_SPECIFIABLE["*"]) - self.assertEqual(KNOWN_SPECIFIABLE["*"]["MyThirdKey"], MyThirdClass) + self.assertTrue(isinstance(C(), Specifiable)) + self.assertIn("C_TYPE", KNOWN_SPECIFIABLE["*"]) + self.assertEqual(KNOWN_SPECIFIABLE["*"]["C_TYPE"], C) def test_init_params_in_specifiable(self): @specifiable - class MyClassWithInitParams(): + class ParentWithInitParams(): def __init__(self, arg_1, arg_2=2, arg_3="3", **kwargs): pass - a = MyClassWithInitParams(10, arg_3="30", arg_4=40) - assert isinstance(a, Specifiable) - self.assertEqual(a._init_params, {'arg_1': 10, 'arg_3': '30', 'arg_4': 40}) + parent = ParentWithInitParams(10, arg_3="30", arg_4=40) + assert isinstance(parent, Specifiable) + self.assertEqual( + parent.init_kwargs, { + 'arg_1': 10, 'arg_3': '30', 'arg_4': 40 + }) - # inheritance of specifiable + # inheritance of a Specifiable subclass @specifiable - class MyDerivedClassWithInitParams(MyClassWithInitParams): + class ChildWithInitParams(ParentWithInitParams): def __init__(self, new_arg_1, new_arg_2=200, new_arg_3="300", **kwargs): super().__init__(**kwargs) - b = MyDerivedClassWithInitParams( + child = ChildWithInitParams( 1000, arg_1=11, arg_2=20, new_arg_2=2000, arg_4=4000) - assert isinstance(b, Specifiable) + assert isinstance(child, Specifiable) self.assertEqual( - b._init_params, + child.init_kwargs, { 'new_arg_1': 1000, 'arg_1': 11, @@ -101,25 +115,44 @@ def __init__(self, new_arg_1, new_arg_2=200, new_arg_3="300", **kwargs): 'arg_4': 4000 }) - # composite of specifiable + # composite of Specifiable subclasses @specifiable - class MyCompositeClassWithInitParams(): - def __init__(self, my_class: Optional[MyClassWithInitParams] = None): + class CompositeWithInitParams(): + def __init__( + self, + my_parent: Optional[ParentWithInitParams] = None, + my_child: Optional[ChildWithInitParams] = None): pass - c = MyCompositeClassWithInitParams(a) - assert isinstance(c, Specifiable) - self.assertEqual(c._init_params, {'my_class': a}) + composite = CompositeWithInitParams(parent, child) + assert isinstance(composite, Specifiable) + self.assertEqual( + composite.init_kwargs, { + 'my_parent': parent, 'my_child': child + }) - def test_from_and_to_specifiable(self): - @specifiable(on_demand_init=False, just_in_time_init=False) + def test_from_spec_on_unknown_spec_type(self): + self.assertRaises(ValueError, Specifiable.from_spec, Spec(type="unknown")) + + # To test from_spec and to_spec with/without just_in_time_init. + @parameterized.expand([(False, False), (True, False), (False, True), + (True, True)]) + def test_from_spec_and_to_spec(self, on_demand_init, just_in_time_init): + @specifiable( + spec_type=f"product_{just_in_time_init}", + on_demand_init=on_demand_init, + just_in_time_init=just_in_time_init, + error_if_exists=False) @dataclasses.dataclass class Product(): name: str price: float @specifiable( - key="shopping_entry", on_demand_init=False, just_in_time_init=False) + spec_type=f"shopping_entry_{just_in_time_init}", + on_demand_init=on_demand_init, + just_in_time_init=just_in_time_init, + error_if_exists=False) class Entry(): def __init__(self, product: Product, quantity: int = 1): self._product = product @@ -131,7 +164,10 @@ def __eq__(self, value) -> bool: self._quantity == value._quantity @specifiable( - key="shopping_cart", on_demand_init=False, just_in_time_init=False) + spec_type=f"shopping_cart_{just_in_time_init}", + on_demand_init=on_demand_init, + just_in_time_init=just_in_time_init, + error_if_exists=False) @dataclasses.dataclass class ShoppingCart(): user_id: str @@ -140,39 +176,38 @@ class ShoppingCart(): orange = Product("orange", 1.0) expected_orange_spec = Spec( - "Product", config={ + f"product_{just_in_time_init}", config={ 'name': 'orange', 'price': 1.0 }) assert isinstance(orange, Specifiable) self.assertEqual(orange.to_spec(), expected_orange_spec) - self.assertEqual(Specifiable.from_spec(expected_orange_spec), orange) entry_1 = Entry(product=orange) expected_entry_spec_1 = Spec( - "shopping_entry", config={ + f"shopping_entry_{just_in_time_init}", + config={ 'product': expected_orange_spec, }) assert isinstance(entry_1, Specifiable) self.assertEqual(entry_1.to_spec(), expected_entry_spec_1) - self.assertEqual(Specifiable.from_spec(expected_entry_spec_1), entry_1) banana = Product("banana", 0.5) expected_banana_spec = Spec( - "Product", config={ + f"product_{just_in_time_init}", config={ 'name': 'banana', 'price': 0.5 }) entry_2 = Entry(product=banana, quantity=5) expected_entry_spec_2 = Spec( - "shopping_entry", + f"shopping_entry_{just_in_time_init}", config={ 'product': expected_banana_spec, 'quantity': 5 }) shopping_cart = ShoppingCart(user_id="test", entries=[entry_1, entry_2]) expected_shopping_cart_spec = Spec( - "shopping_cart", + f"shopping_cart_{just_in_time_init}", config={ "user_id": "test", "entries": [expected_entry_spec_1, expected_entry_spec_2] @@ -180,9 +215,20 @@ class ShoppingCart(): assert isinstance(shopping_cart, Specifiable) self.assertEqual(shopping_cart.to_spec(), expected_shopping_cart_spec) + if on_demand_init and not just_in_time_init: + orange.run_original_init() + banana.run_original_init() + entry_1.run_original_init() + entry_2.run_original_init() + shopping_cart.run_original_init() + + self.assertEqual(Specifiable.from_spec(expected_orange_spec), orange) + self.assertEqual(Specifiable.from_spec(expected_entry_spec_1), entry_1) self.assertEqual( Specifiable.from_spec(expected_shopping_cart_spec), shopping_cart) + +class TestInitCallCount(unittest.TestCase): def test_on_demand_init(self): @specifiable(on_demand_init=True, just_in_time_init=False) class FooOnDemand(): @@ -190,12 +236,12 @@ class FooOnDemand(): def __init__(self, arg): self.my_arg = arg * 10 - FooOnDemand.counter += 1 + FooOnDemand.counter += 1 # increment it when __init__ is called foo = FooOnDemand(123) self.assertEqual(FooOnDemand.counter, 0) - self.assertIn("_init_params", foo.__dict__) - self.assertEqual(foo.__dict__["_init_params"], {"arg": 123}) + self.assertIn("init_kwargs", foo.__dict__) + self.assertEqual(foo.__dict__["init_kwargs"], {"arg": 123}) self.assertNotIn("my_arg", foo.__dict__) self.assertRaises(AttributeError, getattr, foo, "my_arg") @@ -204,10 +250,11 @@ def __init__(self, arg): self.assertRaises(AttributeError, lambda: foo.unknown_arg) self.assertEqual(FooOnDemand.counter, 0) + # __init__ is called when _run_init=True is used foo_2 = FooOnDemand(456, _run_init=True) self.assertEqual(FooOnDemand.counter, 1) - self.assertIn("_init_params", foo_2.__dict__) - self.assertEqual(foo_2.__dict__["_init_params"], {"arg": 456}) + self.assertIn("init_kwargs", foo_2.__dict__) + self.assertEqual(foo_2.__dict__["init_kwargs"], {"arg": 456}) self.assertIn("my_arg", foo_2.__dict__) self.assertEqual(foo_2.my_arg, 4560) @@ -220,17 +267,18 @@ class FooJustInTime(): def __init__(self, arg): self.my_arg = arg * 10 - FooJustInTime.counter += 1 + FooJustInTime.counter += 1 # increment it when __init__ is called foo = FooJustInTime(321) self.assertEqual(FooJustInTime.counter, 0) - self.assertIn("_init_params", foo.__dict__) - self.assertEqual(foo.__dict__["_init_params"], {"arg": 321}) + self.assertIn("init_kwargs", foo.__dict__) + self.assertEqual(foo.__dict__["init_kwargs"], {"arg": 321}) - self.assertNotIn("my_arg", foo.__dict__) # __init__ hasn't been called + # __init__ hasn't been called yet + self.assertNotIn("my_arg", foo.__dict__) self.assertEqual(FooJustInTime.counter, 0) - # __init__ is called when trying to accessing an attribute + # __init__ is called when trying to access a class attribute self.assertEqual(foo.my_arg, 3210) self.assertEqual(FooJustInTime.counter, 1) self.assertRaises(AttributeError, lambda: foo.unknown_arg) @@ -247,23 +295,23 @@ def __init__(self, arg): foo = FooOnDemandAndJustInTime(987) self.assertEqual(FooOnDemandAndJustInTime.counter, 0) - self.assertIn("_init_params", foo.__dict__) - self.assertEqual(foo.__dict__["_init_params"], {"arg": 987}) + self.assertIn("init_kwargs", foo.__dict__) + self.assertEqual(foo.__dict__["init_kwargs"], {"arg": 987}) self.assertNotIn("my_arg", foo.__dict__) self.assertEqual(FooOnDemandAndJustInTime.counter, 0) - # __init__ is called + # __init__ is called when trying to access a class attribute self.assertEqual(foo.my_arg, 9870) self.assertEqual(FooOnDemandAndJustInTime.counter, 1) - # __init__ is called + # __init__ is called when _run_init=True is used foo_2 = FooOnDemandAndJustInTime(789, _run_init=True) self.assertEqual(FooOnDemandAndJustInTime.counter, 2) - self.assertIn("_init_params", foo_2.__dict__) - self.assertEqual(foo_2.__dict__["_init_params"], {"arg": 789}) + self.assertIn("init_kwargs", foo_2.__dict__) + self.assertEqual(foo_2.__dict__["init_kwargs"], {"arg": 789}) self.assertEqual(FooOnDemandAndJustInTime.counter, 2) - # __init__ is NOT called + # __init__ is NOT called after it is initialized self.assertEqual(foo_2.my_arg, 7890) self.assertEqual(FooOnDemandAndJustInTime.counter, 2) @@ -276,7 +324,7 @@ def __init__(self, arg): type(self).counter += 1 def test_on_pickle(self): - FooForPickle = TestSpecifiable.FooForPickle + FooForPickle = TestInitCallCount.FooForPickle import dill FooForPickle.counter = 0 @@ -357,6 +405,7 @@ class Child_Error_1(Parent): child_class_var = 2001 def __init__(self, c): + # read an instance var in child that doesn't exist self.child_inst_var += 1 super().__init__(c) Child_2.counter += 1 @@ -368,6 +417,7 @@ class Child_Error_2(Parent): child_class_var = 2001 def __init__(self, c): + # read an instance var in parent without calling parent's __init__. self.parent_inst_var += 1 Child_2.counter += 1 From ebfa85eca650bf0cd4d076fd2977aec0db00a610 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Fri, 7 Feb 2025 13:04:38 -0500 Subject: [PATCH 07/11] Minor changes to docstrings and comments --- sdks/python/apache_beam/ml/anomaly/base.py | 2 +- .../apache_beam/ml/anomaly/specifiable.py | 62 +++++++++++-------- 2 files changed, 37 insertions(+), 27 deletions(-) diff --git a/sdks/python/apache_beam/ml/anomaly/base.py b/sdks/python/apache_beam/ml/anomaly/base.py index dfe29ee55ee9..6a717cf5db16 100644 --- a/sdks/python/apache_beam/ml/anomaly/base.py +++ b/sdks/python/apache_beam/ml/anomaly/base.py @@ -196,7 +196,7 @@ def __init__( def learn_one(self, x: beam.Row) -> None: """Inherited from `AnomalyDetector.learn_one`. - This method is never called during ensemble detector training. The training + This method is never called during ensemble detector training. The training process is done on each sub-detector independently and in parallel. """ raise NotImplementedError diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index 7765f8998ada..74e3b661cd20 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -45,10 +45,10 @@ ] #: A nested dictionary for efficient lookup of Specifiable subclasses. -#: Structure: KNOWN_SPECIFIABLE[subspace][spec_type], where "subspace" is one of -#: the accepted subspaces that the class belongs to and "spec_type" is the class -#: name by default. Users can also specify a different value for "spec_type" -#: when applying the `specifiable` decorator to an existing class. +#: Structure: `KNOWN_SPECIFIABLE[subspace][spec_type]`, where `subspace` is one +#: of the accepted subspaces that the class belongs to and `spec_type` is the +#: class name by default. Users can also specify a different value for +#: `spec_type` when applying the `specifiable` decorator to an existing class. KNOWN_SPECIFIABLE = {"*": {}} SpecT = TypeVar('SpecT', bound='Specifiable') @@ -58,7 +58,7 @@ def _class_to_subspace(cls: Type, default="*") -> str: """ Search the class hierarchy to find the subspace: the closest ancestor class in the class's method resolution order (MRO) whose name is found in the accepted - subspace list. This is usually called when registering a new Specifiable + subspace list. This is usually called when registering a new specifiable class. """ for c in cls.mro(): @@ -75,7 +75,7 @@ def _class_to_subspace(cls: Type, default="*") -> str: def _spec_type_to_subspace(type: str, default="*") -> str: """ Look for the subspace for a spec type. This is usually called to retrieve - the subspace of a registered Specifiable class. + the subspace of a registered specifiable class. """ for subspace in ACCEPTED_SPECIFIABLE_SUBSPACES: if type in KNOWN_SPECIFIABLE.get(subspace, {}): @@ -92,9 +92,9 @@ class Spec(): """ Dataclass for storing specifications of specifiable objects. Objects can be initialized using the data in their corresponding spec. - The `type` field indicates the concrete Specifiable class, while + The `type` field indicates the concrete `Specifiable` class, while """ - #: A string indicating the concrete Specifiable class + #: A string indicating the concrete `Specifiable` class type: str #: A dictionary of keyword arguments for the `__init__` method of the class. config: dict[str, Any] = dataclasses.field(default_factory=dict) @@ -102,19 +102,20 @@ class Spec(): @runtime_checkable class Specifiable(Protocol): - """Protocol that a Specifiable subclass needs to implement. + """Protocol that a specifiable class needs to implement. Attributes: - spec_type: The value of the `type` field in the object's Spec for this + spec_type: The value of the `type` field in the object's spec for this class. - init_kwargs: The raw keyword arguments passed to `__init__` during object - initialization. + init_kwargs: The raw keyword arguments passed to `__init__` method during + object initialization. """ spec_type: ClassVar[str] init_kwargs: dict[str, Any] - # a boolean to tell whether the original __init__ is called + # a boolean to tell whether the original `__init__` method is called _initialized: bool - # a boolean used by new_getattr to tell whether it is in an __init__ call + # a boolean used by new_getattr to tell whether it is in the `__init__` method + # call _in_init: bool @staticmethod @@ -129,7 +130,7 @@ def _from_spec_helper(v, _run_init): @classmethod def from_spec(cls, spec: Spec, _run_init: bool = True) -> Self: - """Generate a Specifiable subclass object based on a spec.""" + """Generate a `Specifiable` subclass object based on a spec.""" if spec.type is None: raise ValueError(f"Spec type not found in {spec}") @@ -160,7 +161,7 @@ def _to_spec_helper(v): def to_spec(self) -> Spec: """ - Generate a spec from a Specifiable subclass object. + Generate a spec from a `Specifiable` subclass object. """ if getattr(type(self), 'spec_type', None) is None: raise ValueError( @@ -172,7 +173,7 @@ def to_spec(self) -> Spec: return Spec(type=self.__class__.spec_type, config=args) -# Register a Specifiable subclass in KNOWN_SPECIFIABLE +# Register a `Specifiable` subclass in `KNOWN_SPECIFIABLE` def _register(cls, spec_type=None, error_if_exists=True) -> None: if spec_type is None: # By default, spec type is the class name. Users can override this with @@ -190,7 +191,7 @@ def _register(cls, spec_type=None, error_if_exists=True) -> None: cls.spec_type = spec_type -# Keep a copy of arguments that are used to call __init__ method, when the +# Keep a copy of arguments that are used to call the `__init__` method when the # object is initialized. def _get_init_kwargs(inst, init_method, *args, **kwargs): params = dict( @@ -208,15 +209,24 @@ def specifiable( error_if_exists=True, on_demand_init=True, just_in_time_init=True): - """A decorator that turns a class into a Specifiable subclass by implementing - the Specifiable protocol. + """A decorator that turns a class into a `Specifiable` subclass by + implementing the `Specifiable` protocol. - To use the decorator, simply place `@specifiable` before the class definition. - For finer control, the decorator accepts arguments - (e.g., `@specifiable(arg1=..., arg2=...)`). + To use the decorator, simply place `@specifiable` before the class + definition.:: + + @specifiable + class Foo(): + ... + + For finer control, the decorator can accept arguments.:: + + @specifiable(spec_type="My Class", on_demand_init=False) + class Bar(): + ... Args: - spec_type: The value of the `type` field in the Spec of a Specifiable + spec_type: The value of the `type` field in the Spec of a `Specifiable` subclass. If not provided, the class name is used. error_if_exists: If True, raise an exception if `spec_type` is already registered. @@ -224,8 +234,8 @@ def specifiable( `__init__` method will be called when `_run_init=True` is passed to the object's initialization function. just_in_time_init: If True, allow just-in-time object initialization. The - original `__init__` method will be called when an attribute is first - accessed. + original `__init__` method will be called when the first time an attribute + is accessed. """ def _wrapper(cls): def new_init(self: Specifiable, *args, **kwargs): From 5f0debf63985f75ac210deb21db2998ede128056 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Fri, 7 Feb 2025 15:33:14 -0500 Subject: [PATCH 08/11] Remove the fallback subspace '*' from accepted list. Use it in tests only. --- .../apache_beam/ml/anomaly/specifiable.py | 29 +++++++++---------- .../ml/anomaly/specifiable_test.py | 23 +++++++++------ 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index 74e3b661cd20..cb93f60209f6 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -36,25 +36,28 @@ __all__ = ["KNOWN_SPECIFIABLE", "Spec", "Specifiable", "specifiable"] -ACCEPTED_SPECIFIABLE_SUBSPACES = [ +ACCEPTED_SUBSPACES = [ "EnsembleAnomalyDetector", "AnomalyDetector", "ThresholdFn", "AggregationFn", - "*" ] +# By default, the fallback subspace is not in the accepted subspace list. +# We only use this fallback subspace in tests. +FALLBACK_SUBSPACE = "my test subspace" + #: A nested dictionary for efficient lookup of Specifiable subclasses. #: Structure: `KNOWN_SPECIFIABLE[subspace][spec_type]`, where `subspace` is one #: of the accepted subspaces that the class belongs to and `spec_type` is the #: class name by default. Users can also specify a different value for #: `spec_type` when applying the `specifiable` decorator to an existing class. -KNOWN_SPECIFIABLE = {"*": {}} +KNOWN_SPECIFIABLE = {} SpecT = TypeVar('SpecT', bound='Specifiable') -def _class_to_subspace(cls: Type, default="*") -> str: +def _class_to_subspace(cls: Type) -> str: """ Search the class hierarchy to find the subspace: the closest ancestor class in the class's method resolution order (MRO) whose name is found in the accepted @@ -62,29 +65,25 @@ def _class_to_subspace(cls: Type, default="*") -> str: class. """ for c in cls.mro(): - # - if c.__name__ in ACCEPTED_SPECIFIABLE_SUBSPACES: + if c.__name__ in ACCEPTED_SUBSPACES: return c.__name__ - if default is None: - raise ValueError(f"subspace for {cls.__name__} not found.") + if FALLBACK_SUBSPACE in ACCEPTED_SUBSPACES: + return FALLBACK_SUBSPACE - return default + raise ValueError(f"subspace for {cls.__name__} not found.") -def _spec_type_to_subspace(type: str, default="*") -> str: +def _spec_type_to_subspace(type: str) -> str: """ Look for the subspace for a spec type. This is usually called to retrieve the subspace of a registered specifiable class. """ - for subspace in ACCEPTED_SPECIFIABLE_SUBSPACES: + for subspace in ACCEPTED_SUBSPACES: if type in KNOWN_SPECIFIABLE.get(subspace, {}): return subspace - if default is None: - raise ValueError(f"subspace for {type} not found.") - - return default + raise ValueError(f"subspace for {str} not found.") @dataclasses.dataclass(frozen=True) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py index 108dca49df7d..2dee66d2720b 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py @@ -23,11 +23,16 @@ from parameterized import parameterized +from apache_beam.ml.anomaly.specifiable import ACCEPTED_SUBSPACES +from apache_beam.ml.anomaly.specifiable import FALLBACK_SUBSPACE from apache_beam.ml.anomaly.specifiable import KNOWN_SPECIFIABLE from apache_beam.ml.anomaly.specifiable import Spec from apache_beam.ml.anomaly.specifiable import Specifiable from apache_beam.ml.anomaly.specifiable import specifiable +# The fallback subspace is only accepted during testing. +ACCEPTED_SUBSPACES.append(FALLBACK_SUBSPACE) + class TestSpecifiable(unittest.TestCase): def test_decorator_in_function_form(self): @@ -35,14 +40,14 @@ class A(): pass # class is not decorated and thus not registered - self.assertNotIn("A", KNOWN_SPECIFIABLE["*"]) + self.assertNotIn("A", KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]) # apply the decorator function to an existing class A = specifiable(A) self.assertEqual(A.spec_type, "A") self.assertTrue(isinstance(A(), Specifiable)) - self.assertIn("A", KNOWN_SPECIFIABLE["*"]) - self.assertEqual(KNOWN_SPECIFIABLE["*"]["A"], A) + self.assertIn("A", KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]) + self.assertEqual(KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]["A"], A) # an error is raised if the specified spec_type already exists. self.assertRaises(ValueError, specifiable, A) @@ -52,8 +57,8 @@ class A(): A = specifiable(spec_type="A_DUP")(A) self.assertEqual(A.spec_type, "A_DUP") self.assertTrue(isinstance(A(), Specifiable)) - self.assertIn("A_DUP", KNOWN_SPECIFIABLE["*"]) - self.assertEqual(KNOWN_SPECIFIABLE["*"]["A_DUP"], A) + self.assertIn("A_DUP", KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]) + self.assertEqual(KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]["A_DUP"], A) # an error is raised if the specified spec_type already exists. self.assertRaises(ValueError, specifiable(spec_type="A_DUP"), A) @@ -71,8 +76,8 @@ class B(): pass self.assertTrue(isinstance(B(), Specifiable)) - self.assertIn("B", KNOWN_SPECIFIABLE["*"]) - self.assertEqual(KNOWN_SPECIFIABLE["*"]["B"], B) + self.assertIn("B", KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]) + self.assertEqual(KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]["B"], B) # call decorator with parameters @specifiable(spec_type="C_TYPE") @@ -80,8 +85,8 @@ class C(): pass self.assertTrue(isinstance(C(), Specifiable)) - self.assertIn("C_TYPE", KNOWN_SPECIFIABLE["*"]) - self.assertEqual(KNOWN_SPECIFIABLE["*"]["C_TYPE"], C) + self.assertIn("C_TYPE", KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]) + self.assertEqual(KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]["C_TYPE"], C) def test_init_params_in_specifiable(self): @specifiable From 38b0a89b73f58e6109d92aca961ecc49eb152cb4 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Fri, 7 Feb 2025 16:31:58 -0500 Subject: [PATCH 09/11] Bring fallback subspace back to accepted list. Clarify the use of spec_type to resolve naming conclict. --- .../apache_beam/ml/anomaly/specifiable.py | 49 ++++++++++--------- .../ml/anomaly/specifiable_test.py | 26 +++++----- 2 files changed, 37 insertions(+), 38 deletions(-) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index cb93f60209f6..cf9044bdc135 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -34,25 +34,24 @@ from typing_extensions import Self -__all__ = ["KNOWN_SPECIFIABLE", "Spec", "Specifiable", "specifiable"] +__all__ = ["Spec", "Specifiable", "specifiable"] -ACCEPTED_SUBSPACES = [ +_FALLBACK_SUBSPACE = "*" + +_ACCEPTED_SUBSPACES = [ "EnsembleAnomalyDetector", "AnomalyDetector", "ThresholdFn", "AggregationFn", + _FALLBACK_SUBSPACE, ] -# By default, the fallback subspace is not in the accepted subspace list. -# We only use this fallback subspace in tests. -FALLBACK_SUBSPACE = "my test subspace" - #: A nested dictionary for efficient lookup of Specifiable subclasses. -#: Structure: `KNOWN_SPECIFIABLE[subspace][spec_type]`, where `subspace` is one +#: Structure: `_KNOWN_SPECIFIABLE[subspace][spec_type]`, where `subspace` is one #: of the accepted subspaces that the class belongs to and `spec_type` is the #: class name by default. Users can also specify a different value for #: `spec_type` when applying the `specifiable` decorator to an existing class. -KNOWN_SPECIFIABLE = {} +_KNOWN_SPECIFIABLE = {} SpecT = TypeVar('SpecT', bound='Specifiable') @@ -65,13 +64,10 @@ def _class_to_subspace(cls: Type) -> str: class. """ for c in cls.mro(): - if c.__name__ in ACCEPTED_SUBSPACES: + if c.__name__ in _ACCEPTED_SUBSPACES: return c.__name__ - if FALLBACK_SUBSPACE in ACCEPTED_SUBSPACES: - return FALLBACK_SUBSPACE - - raise ValueError(f"subspace for {cls.__name__} not found.") + return _FALLBACK_SUBSPACE def _spec_type_to_subspace(type: str) -> str: @@ -79,8 +75,8 @@ def _spec_type_to_subspace(type: str) -> str: Look for the subspace for a spec type. This is usually called to retrieve the subspace of a registered specifiable class. """ - for subspace in ACCEPTED_SUBSPACES: - if type in KNOWN_SPECIFIABLE.get(subspace, {}): + for subspace in _ACCEPTED_SUBSPACES: + if type in _KNOWN_SPECIFIABLE.get(subspace, {}): return subspace raise ValueError(f"subspace for {str} not found.") @@ -91,7 +87,6 @@ class Spec(): """ Dataclass for storing specifications of specifiable objects. Objects can be initialized using the data in their corresponding spec. - The `type` field indicates the concrete `Specifiable` class, while """ #: A string indicating the concrete `Specifiable` class type: str @@ -111,6 +106,7 @@ class Specifiable(Protocol): """ spec_type: ClassVar[str] init_kwargs: dict[str, Any] + # a boolean to tell whether the original `__init__` method is called _initialized: bool # a boolean used by new_getattr to tell whether it is in the `__init__` method @@ -134,7 +130,7 @@ def from_spec(cls, spec: Spec, _run_init: bool = True) -> Self: raise ValueError(f"Spec type not found in {spec}") subspace = _spec_type_to_subspace(spec.type) - subclass: Type[Self] = KNOWN_SPECIFIABLE[subspace].get(spec.type, None) + subclass: Type[Self] = _KNOWN_SPECIFIABLE[subspace].get(spec.type, None) if subclass is None: raise ValueError(f"Unknown spec type '{spec.type}' in {spec}") @@ -180,12 +176,17 @@ def _register(cls, spec_type=None, error_if_exists=True) -> None: spec_type = cls.__name__ subspace = _class_to_subspace(cls) - if subspace in KNOWN_SPECIFIABLE: - if spec_type in KNOWN_SPECIFIABLE[subspace] and error_if_exists: - raise ValueError(f"{spec_type} is already registered for specifiable") + if subspace in _KNOWN_SPECIFIABLE: + if spec_type in _KNOWN_SPECIFIABLE[subspace] and error_if_exists: + raise ValueError( + f"{spec_type} is already registered for " + f"specifiable class {_KNOWN_SPECIFIABLE[subspace]}. " + "Please specify a different spec_type by " + "@specifiable(spec_type=...) or ignore the error by " + "@specifiable(error_if_exists=False).") else: - KNOWN_SPECIFIABLE[subspace] = {} - KNOWN_SPECIFIABLE[subspace][spec_type] = cls + _KNOWN_SPECIFIABLE[subspace] = {} + _KNOWN_SPECIFIABLE[subspace][spec_type] = cls cls.spec_type = spec_type @@ -226,7 +227,9 @@ class Bar(): Args: spec_type: The value of the `type` field in the Spec of a `Specifiable` - subclass. If not provided, the class name is used. + subclass. If not provided, the class name is used. This argument is useful + when registering multiple classes with the same base name; in such cases, + one can specify `spec_type` to different values to resolve conflict. error_if_exists: If True, raise an exception if `spec_type` is already registered. on_demand_init: If True, allow on-demand object initialization. The original diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py index 2dee66d2720b..62dc6dd5815c 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py @@ -23,16 +23,12 @@ from parameterized import parameterized -from apache_beam.ml.anomaly.specifiable import ACCEPTED_SUBSPACES -from apache_beam.ml.anomaly.specifiable import FALLBACK_SUBSPACE -from apache_beam.ml.anomaly.specifiable import KNOWN_SPECIFIABLE +from apache_beam.ml.anomaly.specifiable import _FALLBACK_SUBSPACE +from apache_beam.ml.anomaly.specifiable import _KNOWN_SPECIFIABLE from apache_beam.ml.anomaly.specifiable import Spec from apache_beam.ml.anomaly.specifiable import Specifiable from apache_beam.ml.anomaly.specifiable import specifiable -# The fallback subspace is only accepted during testing. -ACCEPTED_SUBSPACES.append(FALLBACK_SUBSPACE) - class TestSpecifiable(unittest.TestCase): def test_decorator_in_function_form(self): @@ -40,14 +36,14 @@ class A(): pass # class is not decorated and thus not registered - self.assertNotIn("A", KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]) + self.assertNotIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) # apply the decorator function to an existing class A = specifiable(A) self.assertEqual(A.spec_type, "A") self.assertTrue(isinstance(A(), Specifiable)) - self.assertIn("A", KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]) - self.assertEqual(KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]["A"], A) + self.assertIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) + self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A"], A) # an error is raised if the specified spec_type already exists. self.assertRaises(ValueError, specifiable, A) @@ -57,8 +53,8 @@ class A(): A = specifiable(spec_type="A_DUP")(A) self.assertEqual(A.spec_type, "A_DUP") self.assertTrue(isinstance(A(), Specifiable)) - self.assertIn("A_DUP", KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]) - self.assertEqual(KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]["A_DUP"], A) + self.assertIn("A_DUP", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) + self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A_DUP"], A) # an error is raised if the specified spec_type already exists. self.assertRaises(ValueError, specifiable(spec_type="A_DUP"), A) @@ -76,8 +72,8 @@ class B(): pass self.assertTrue(isinstance(B(), Specifiable)) - self.assertIn("B", KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]) - self.assertEqual(KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]["B"], B) + self.assertIn("B", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) + self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["B"], B) # call decorator with parameters @specifiable(spec_type="C_TYPE") @@ -85,8 +81,8 @@ class C(): pass self.assertTrue(isinstance(C(), Specifiable)) - self.assertIn("C_TYPE", KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]) - self.assertEqual(KNOWN_SPECIFIABLE[FALLBACK_SUBSPACE]["C_TYPE"], C) + self.assertIn("C_TYPE", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) + self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["C_TYPE"], C) def test_init_params_in_specifiable(self): @specifiable From 9dba6254efb3a7c9d6b5e9836c6860e56d0e713d Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Fri, 7 Feb 2025 19:57:36 -0500 Subject: [PATCH 10/11] Make _KNOWN_SPECIFIABLE a defaultdict. Remove error_if_exiists. --- .../apache_beam/ml/anomaly/base_test.py | 29 +++++++-------- .../apache_beam/ml/anomaly/specifiable.py | 30 ++++++--------- .../ml/anomaly/specifiable_test.py | 37 ++++++------------- 3 files changed, 38 insertions(+), 58 deletions(-) diff --git a/sdks/python/apache_beam/ml/anomaly/base_test.py b/sdks/python/apache_beam/ml/anomaly/base_test.py index 715d5128ee1e..e58674d8c1e9 100644 --- a/sdks/python/apache_beam/ml/anomaly/base_test.py +++ b/sdks/python/apache_beam/ml/anomaly/base_test.py @@ -26,19 +26,22 @@ from apache_beam.ml.anomaly.base import AnomalyDetector from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector from apache_beam.ml.anomaly.base import ThresholdFn +from apache_beam.ml.anomaly.specifiable import _KNOWN_SPECIFIABLE from apache_beam.ml.anomaly.specifiable import Spec from apache_beam.ml.anomaly.specifiable import Specifiable from apache_beam.ml.anomaly.specifiable import specifiable class TestAnomalyDetector(unittest.TestCase): + def setUp(self) -> None: + # Remove all registered specifiable classes and reset. + _KNOWN_SPECIFIABLE.clear() + @parameterized.expand([(False, False), (True, False), (False, True), (True, True)]) def test_model_id_and_spec(self, on_demand_init, just_in_time_init): @specifiable( - on_demand_init=on_demand_init, - just_in_time_init=just_in_time_init, - error_if_exists=False) + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) class DummyThreshold(ThresholdFn): def __init__(self, my_threshold_arg=None): ... @@ -53,9 +56,7 @@ def apply(self, x): ... @specifiable( - on_demand_init=on_demand_init, - just_in_time_init=just_in_time_init, - error_if_exists=False) + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) class Dummy(AnomalyDetector): def __init__(self, my_arg=None, **kwargs): self._my_arg = my_arg @@ -140,21 +141,21 @@ def __eq__(self, value) -> bool: class TestEnsembleAnomalyDetector(unittest.TestCase): + def setUp(self) -> None: + # Remove all registered specifiable classes and reset. + _KNOWN_SPECIFIABLE.clear() + @parameterized.expand([(False, False), (True, False), (False, True), (True, True)]) def test_model_id_and_spec(self, on_demand_init, just_in_time_init): @specifiable( - on_demand_init=on_demand_init, - just_in_time_init=just_in_time_init, - error_if_exists=False) + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) class DummyAggregation(AggregationFn): def apply(self, x): ... @specifiable( - on_demand_init=on_demand_init, - just_in_time_init=just_in_time_init, - error_if_exists=False) + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) class DummyEnsemble(EnsembleAnomalyDetector): def __init__(self, my_ensemble_arg=None, **kwargs): super().__init__(**kwargs) @@ -171,9 +172,7 @@ def __eq__(self, value) -> bool: self._my_ensemble_arg == value._my_ensemble_arg @specifiable( - on_demand_init=on_demand_init, - just_in_time_init=just_in_time_init, - error_if_exists=False) + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) class DummyWeakLearner(AnomalyDetector): def __init__(self, my_arg=None, **kwargs): super().__init__(**kwargs) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index cf9044bdc135..81dba5c87982 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -21,6 +21,7 @@ from __future__ import annotations +import collections import dataclasses import inspect import logging @@ -51,7 +52,7 @@ #: of the accepted subspaces that the class belongs to and `spec_type` is the #: class name by default. Users can also specify a different value for #: `spec_type` when applying the `specifiable` decorator to an existing class. -_KNOWN_SPECIFIABLE = {} +_KNOWN_SPECIFIABLE = collections.defaultdict(dict) SpecT = TypeVar('SpecT', bound='Specifiable') @@ -70,13 +71,13 @@ def _class_to_subspace(cls: Type) -> str: return _FALLBACK_SUBSPACE -def _spec_type_to_subspace(type: str) -> str: +def _spec_type_to_subspace(spec_type: str) -> str: """ Look for the subspace for a spec type. This is usually called to retrieve the subspace of a registered specifiable class. """ for subspace in _ACCEPTED_SUBSPACES: - if type in _KNOWN_SPECIFIABLE.get(subspace, {}): + if spec_type in _KNOWN_SPECIFIABLE[subspace]: return subspace raise ValueError(f"subspace for {str} not found.") @@ -169,24 +170,20 @@ def to_spec(self) -> Spec: # Register a `Specifiable` subclass in `KNOWN_SPECIFIABLE` -def _register(cls, spec_type=None, error_if_exists=True) -> None: +def _register(cls, spec_type=None) -> None: if spec_type is None: # By default, spec type is the class name. Users can override this with # other unique identifier. spec_type = cls.__name__ subspace = _class_to_subspace(cls) - if subspace in _KNOWN_SPECIFIABLE: - if spec_type in _KNOWN_SPECIFIABLE[subspace] and error_if_exists: - raise ValueError( - f"{spec_type} is already registered for " - f"specifiable class {_KNOWN_SPECIFIABLE[subspace]}. " - "Please specify a different spec_type by " - "@specifiable(spec_type=...) or ignore the error by " - "@specifiable(error_if_exists=False).") + if spec_type in _KNOWN_SPECIFIABLE[subspace]: + raise ValueError( + f"{spec_type} is already registered for " + f"specifiable class {_KNOWN_SPECIFIABLE[subspace][spec_type]}. " + "Please specify a different spec_type by @specifiable(spec_type=...).") else: - _KNOWN_SPECIFIABLE[subspace] = {} - _KNOWN_SPECIFIABLE[subspace][spec_type] = cls + _KNOWN_SPECIFIABLE[subspace][spec_type] = cls cls.spec_type = spec_type @@ -206,7 +203,6 @@ def specifiable( /, *, spec_type=None, - error_if_exists=True, on_demand_init=True, just_in_time_init=True): """A decorator that turns a class into a `Specifiable` subclass by @@ -230,8 +226,6 @@ class Bar(): subclass. If not provided, the class name is used. This argument is useful when registering multiple classes with the same base name; in such cases, one can specify `spec_type` to different values to resolve conflict. - error_if_exists: If True, raise an exception if `spec_type` is already - registered. on_demand_init: If True, allow on-demand object initialization. The original `__init__` method will be called when `_run_init=True` is passed to the object's initialization function. @@ -306,7 +300,7 @@ def new_getattr(self, name): return self.__getattribute__(name) # start of the function body of _wrapper - _register(cls, spec_type, error_if_exists) + _register(cls, spec_type) class_name = cls.__name__ original_init = cls.__init__ diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py index 62dc6dd5815c..19b9d81c3d53 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py @@ -31,6 +31,10 @@ class TestSpecifiable(unittest.TestCase): + def setUp(self) -> None: + # Remove all registered specifiable classes and reset. + _KNOWN_SPECIFIABLE.clear() + def test_decorator_in_function_form(self): class A(): pass @@ -59,12 +63,6 @@ class A(): # an error is raised if the specified spec_type already exists. self.assertRaises(ValueError, specifiable(spec_type="A_DUP"), A) - # but the error can be suppressed by setting error_if_exists=False. - try: - specifiable(spec_type="A_DUP", error_if_exists=False)(A) - except ValueError: - self.fail("The ValueError should be suppressed but instead it is raised.") - def test_decorator_in_syntactic_sugar_form(self): # call decorator without parameters @specifiable @@ -140,20 +138,14 @@ def test_from_spec_on_unknown_spec_type(self): (True, True)]) def test_from_spec_and_to_spec(self, on_demand_init, just_in_time_init): @specifiable( - spec_type=f"product_{just_in_time_init}", - on_demand_init=on_demand_init, - just_in_time_init=just_in_time_init, - error_if_exists=False) + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) @dataclasses.dataclass class Product(): name: str price: float @specifiable( - spec_type=f"shopping_entry_{just_in_time_init}", - on_demand_init=on_demand_init, - just_in_time_init=just_in_time_init, - error_if_exists=False) + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) class Entry(): def __init__(self, product: Product, quantity: int = 1): self._product = product @@ -165,10 +157,7 @@ def __eq__(self, value) -> bool: self._quantity == value._quantity @specifiable( - spec_type=f"shopping_cart_{just_in_time_init}", - on_demand_init=on_demand_init, - just_in_time_init=just_in_time_init, - error_if_exists=False) + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) @dataclasses.dataclass class ShoppingCart(): user_id: str @@ -177,7 +166,7 @@ class ShoppingCart(): orange = Product("orange", 1.0) expected_orange_spec = Spec( - f"product_{just_in_time_init}", config={ + "Product", config={ 'name': 'orange', 'price': 1.0 }) assert isinstance(orange, Specifiable) @@ -186,8 +175,7 @@ class ShoppingCart(): entry_1 = Entry(product=orange) expected_entry_spec_1 = Spec( - f"shopping_entry_{just_in_time_init}", - config={ + "Entry", config={ 'product': expected_orange_spec, }) @@ -196,19 +184,18 @@ class ShoppingCart(): banana = Product("banana", 0.5) expected_banana_spec = Spec( - f"product_{just_in_time_init}", config={ + "Product", config={ 'name': 'banana', 'price': 0.5 }) entry_2 = Entry(product=banana, quantity=5) expected_entry_spec_2 = Spec( - f"shopping_entry_{just_in_time_init}", - config={ + "Entry", config={ 'product': expected_banana_spec, 'quantity': 5 }) shopping_cart = ShoppingCart(user_id="test", entries=[entry_1, entry_2]) expected_shopping_cart_spec = Spec( - f"shopping_cart_{just_in_time_init}", + "ShoppingCart", config={ "user_id": "test", "entries": [expected_entry_spec_1, expected_entry_spec_2] From fb1d3b393716550358853d20dfb3d32b9e390444 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Fri, 7 Feb 2025 21:08:41 -0500 Subject: [PATCH 11/11] Minor adjustment on docstrings. --- .../apache_beam/ml/anomaly/specifiable.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index 81dba5c87982..1aedab2e8c21 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -97,15 +97,11 @@ class Spec(): @runtime_checkable class Specifiable(Protocol): - """Protocol that a specifiable class needs to implement. - - Attributes: - spec_type: The value of the `type` field in the object's spec for this - class. - init_kwargs: The raw keyword arguments passed to `__init__` method during - object initialization. - """ + """Protocol that a specifiable class needs to implement.""" + #: The value of the `type` field in the object's spec for this class. spec_type: ClassVar[str] + #: The raw keyword arguments passed to `__init__` method during object + #: initialization. init_kwargs: dict[str, Any] # a boolean to tell whether the original `__init__` method is called @@ -156,9 +152,7 @@ def _to_spec_helper(v): return v def to_spec(self) -> Spec: - """ - Generate a spec from a `Specifiable` subclass object. - """ + """Generate a spec from a `Specifiable` subclass object.""" if getattr(type(self), 'spec_type', None) is None: raise ValueError( f"'{type(self).__name__}' not registered as Specifiable. " @@ -168,6 +162,10 @@ def to_spec(self) -> Spec: return Spec(type=self.__class__.spec_type, config=args) + def run_original_init(self) -> None: + """Invoke the original __init__ method with original keyword arguments""" + pass + # Register a `Specifiable` subclass in `KNOWN_SPECIFIABLE` def _register(cls, spec_type=None) -> None: @@ -209,13 +207,13 @@ def specifiable( implementing the `Specifiable` protocol. To use the decorator, simply place `@specifiable` before the class - definition.:: + definition:: @specifiable class Foo(): ... - For finer control, the decorator can accept arguments.:: + For finer control, the decorator can accept arguments:: @specifiable(spec_type="My Class", on_demand_init=False) class Bar():