diff --git a/sdks/python/apache_beam/ml/anomaly/__init__.py b/sdks/python/apache_beam/ml/anomaly/__init__.py new file mode 100644 index 000000000000..cce3acad34a4 --- /dev/null +++ b/sdks/python/apache_beam/ml/anomaly/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/sdks/python/apache_beam/ml/anomaly/base.py b/sdks/python/apache_beam/ml/anomaly/base.py new file mode 100644 index 000000000000..6a717cf5db16 --- /dev/null +++ b/sdks/python/apache_beam/ml/anomaly/base.py @@ -0,0 +1,211 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Base classes for anomaly detection +""" +from __future__ import annotations + +import abc +from dataclasses import dataclass +from typing import Iterable +from typing import List +from typing import Optional + +import apache_beam as beam + +__all__ = [ + "AnomalyPrediction", + "AnomalyResult", + "ThresholdFn", + "AggregationFn", + "AnomalyDetector", + "EnsembleAnomalyDetector" +] + + +@dataclass(frozen=True) +class AnomalyPrediction(): + """A dataclass for anomaly detection predictions.""" + #: The ID of detector (model) that generates the prediction. + model_id: Optional[str] = None + #: The outlier score resulting from applying the detector to the input data. + score: Optional[float] = None + #: The outlier label (normal or outlier) derived from the outlier score. + label: Optional[int] = None + #: The threshold used to determine the label. + threshold: Optional[float] = None + #: Additional information about the prediction. + info: str = "" + #: If enabled, a list of `AnomalyPrediction` objects used to derive the + #: aggregated prediction. + agg_history: Optional[Iterable[AnomalyPrediction]] = None + + +@dataclass(frozen=True) +class AnomalyResult(): + """A dataclass for the anomaly detection results""" + #: The original input data. + example: beam.Row + #: The `AnomalyPrediction` object containing the prediction. + prediction: AnomalyPrediction + + +class ThresholdFn(abc.ABC): + """An abstract base class for threshold functions. 
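
  As a sketch (not part of this module; the class name and `cutoff` argument
  are illustrative), a concrete subclass could look like::

    class FixedThreshold(ThresholdFn):
      def __init__(self, cutoff, **kwargs):
        super().__init__(**kwargs)
        self._cutoff = cutoff

      @property
      def is_stateful(self) -> bool:
        return False

      @property
      def threshold(self) -> Optional[float]:
        return self._cutoff

      def apply(self, score: Optional[float]) -> int:
        if score is not None and score >= self._cutoff:
          return self._outlier_label
        return self._normal_label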
+ + Args: + normal_label: The integer label used to identify normal data. Defaults to 0. + outlier_label: The integer label used to identify outlier data. Defaults to + 1. + """ + def __init__(self, normal_label: int = 0, outlier_label: int = 1): + self._normal_label = normal_label + self._outlier_label = outlier_label + + @property + @abc.abstractmethod + def is_stateful(self) -> bool: + """Indicates whether the threshold function is stateful or not.""" + raise NotImplementedError + + @property + @abc.abstractmethod + def threshold(self) -> Optional[float]: + """Retrieves the current threshold value, or None if not set.""" + raise NotImplementedError + + @abc.abstractmethod + def apply(self, score: Optional[float]) -> int: + """Applies the threshold function to a given score to classify it as + normal or outlier. + + Args: + score: The outlier score generated from the detector (model). + + Returns: + The label assigned to the score, either `self._normal_label` + or `self._outlier_label` + """ + raise NotImplementedError + + +class AggregationFn(abc.ABC): + """An abstract base class for aggregation functions.""" + @abc.abstractmethod + def apply( + self, predictions: Iterable[AnomalyPrediction]) -> AnomalyPrediction: + """Applies the aggregation function to an iterable of predictions, either on + their outlier scores or labels. + + Args: + predictions: An Iterable of `AnomalyPrediction` objects to aggregate. + + Returns: + An `AnomalyPrediction` object containing the aggregated result. + """ + raise NotImplementedError + + +class AnomalyDetector(abc.ABC): + """An abstract base class for anomaly detectors. + + Args: + model_id: The ID of detector (model). Defaults to the value of the + `spec_type` attribute, or 'unknown' if not set. + features: An Iterable of strings representing the names of the input + features in the `beam.Row` + target: The name of the target field in the `beam.Row`. + threshold_criterion: An optional `ThresholdFn` to apply to the outlier score + and yield a label. + """ + def __init__( + self, + model_id: Optional[str] = None, + features: Optional[Iterable[str]] = None, + target: Optional[str] = None, + threshold_criterion: Optional[ThresholdFn] = None, + **kwargs): + self._model_id = model_id if model_id is not None else getattr( + self, 'spec_type', 'unknown') + self._features = features + self._target = target + self._threshold_criterion = threshold_criterion + + @abc.abstractmethod + def learn_one(self, x: beam.Row) -> None: + """Trains the detector on a single data instance. + + Args: + x: A `beam.Row` representing the data instance. + """ + raise NotImplementedError + + @abc.abstractmethod + def score_one(self, x: beam.Row) -> float: + """Scores a single data instance for anomalies. + + Args: + x: A `beam.Row` representing the data instance. + + Returns: + The outlier score as a float. + """ + raise NotImplementedError + + +class EnsembleAnomalyDetector(AnomalyDetector): + """An abstract base class for an ensemble of anomaly (sub-)detectors. + + Args: + sub_detectors: A List of `AnomalyDetector` used in this ensemble model. + aggregation_strategy: An optional `AggregationFn` to apply to the + predictions from all sub-detectors and yield an aggregated result. + model_id: Inherited from `AnomalyDetector`. + features: Inherited from `AnomalyDetector`. + target: Inherited from `AnomalyDetector`. + threshold_criterion: Inherited from `AnomalyDetector`. 
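
  A construction sketch (`detector_a`, `detector_b` and `my_aggregation_fn`
  are placeholders, not part of this module)::

    ensemble = EnsembleAnomalyDetector(
        sub_detectors=[detector_a, detector_b],
        aggregation_strategy=my_aggregation_fn,
        model_id="my_ensemble")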
+ """ + def __init__( + self, + sub_detectors: Optional[List[AnomalyDetector]] = None, + aggregation_strategy: Optional[AggregationFn] = None, + **kwargs): + if "model_id" not in kwargs or kwargs["model_id"] is None: + kwargs["model_id"] = getattr(self, 'spec_type', 'custom') + + super().__init__(**kwargs) + + self._aggregation_strategy = aggregation_strategy + self._sub_detectors = sub_detectors + + def learn_one(self, x: beam.Row) -> None: + """Inherited from `AnomalyDetector.learn_one`. + + This method is never called during ensemble detector training. The training + process is done on each sub-detector independently and in parallel. + """ + raise NotImplementedError + + def score_one(self, x: beam.Row) -> float: + """Inherited from `AnomalyDetector.score_one`. + + This method is never called during ensemble detector scoring. The scoring + process is done on sub-detector independently and in parallel, and then + the results are aggregated in the pipeline. + """ + raise NotImplementedError diff --git a/sdks/python/apache_beam/ml/anomaly/base_test.py b/sdks/python/apache_beam/ml/anomaly/base_test.py new file mode 100644 index 000000000000..e58674d8c1e9 --- /dev/null +++ b/sdks/python/apache_beam/ml/anomaly/base_test.py @@ -0,0 +1,241 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import logging +import unittest + +from parameterized import parameterized + +from apache_beam.ml.anomaly.base import AggregationFn +from apache_beam.ml.anomaly.base import AnomalyDetector +from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector +from apache_beam.ml.anomaly.base import ThresholdFn +from apache_beam.ml.anomaly.specifiable import _KNOWN_SPECIFIABLE +from apache_beam.ml.anomaly.specifiable import Spec +from apache_beam.ml.anomaly.specifiable import Specifiable +from apache_beam.ml.anomaly.specifiable import specifiable + + +class TestAnomalyDetector(unittest.TestCase): + def setUp(self) -> None: + # Remove all registered specifiable classes and reset. + _KNOWN_SPECIFIABLE.clear() + + @parameterized.expand([(False, False), (True, False), (False, True), + (True, True)]) + def test_model_id_and_spec(self, on_demand_init, just_in_time_init): + @specifiable( + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) + class DummyThreshold(ThresholdFn): + def __init__(self, my_threshold_arg=None): + ... + + def is_stateful(self): + return False + + def threshold(self): + ... + + def apply(self, x): + ... + + @specifiable( + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) + class Dummy(AnomalyDetector): + def __init__(self, my_arg=None, **kwargs): + self._my_arg = my_arg + super().__init__(**kwargs) + + def learn_one(self): + ... + + def score_one(self): + ... 
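      # __eq__ compares the custom argument so that the from_spec round-trip
      # check below (assertEqual(b, b_dup)) can compare detectors by value.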
+ + def __eq__(self, value) -> bool: + return isinstance(value, Dummy) and \ + self._my_arg == value._my_arg + + a = Dummy( + my_arg="abc", + target="ABC", + threshold_criterion=(t1 := DummyThreshold(2))) + + # The class attributes can only be accessed when + # (1) on_demand_init == False, just_in_time_init == False + # In this case, the true __init__ is called immediately during object + # initialization + # (2) just_in_time_init == True + # In this case, regardless what on_demand_init is, the true __init__ + # is called when we first access any class attribute (delay init). + if just_in_time_init or not on_demand_init: + self.assertEqual(a._model_id, "Dummy") + self.assertEqual(a._target, "ABC") + self.assertEqual(a._my_arg, "abc") + + assert isinstance(a, Specifiable) + self.assertEqual( + a.init_kwargs, { + "my_arg": "abc", + "target": "ABC", + "threshold_criterion": t1, + }) + + b = Dummy( + my_arg="efg", + model_id="my_dummy", + target="EFG", + threshold_criterion=(t2 := DummyThreshold(3))) + + # See the comment above for more details. + if just_in_time_init or not on_demand_init: + self.assertEqual(b._model_id, "my_dummy") + self.assertEqual(b._target, "EFG") + self.assertEqual(b._my_arg, "efg") + + assert isinstance(b, Specifiable) + self.assertEqual( + b.init_kwargs, + { + "model_id": "my_dummy", + "my_arg": "efg", + "target": "EFG", + "threshold_criterion": t2, + }) + + spec = b.to_spec() + expected_spec = Spec( + type="Dummy", + config={ + "my_arg": "efg", + "model_id": "my_dummy", + "target": "EFG", + "threshold_criterion": Spec( + type="DummyThreshold", config={"my_threshold_arg": 3}), + }) + self.assertEqual(spec, expected_spec) + + b_dup = Specifiable.from_spec(spec) + # We need to manually call the original __init__ function in one scenario. + # See the comment above for more details. + if on_demand_init and not just_in_time_init: + b.run_original_init() + + self.assertEqual(b, b_dup) + + +class TestEnsembleAnomalyDetector(unittest.TestCase): + def setUp(self) -> None: + # Remove all registered specifiable classes and reset. + _KNOWN_SPECIFIABLE.clear() + + @parameterized.expand([(False, False), (True, False), (False, True), + (True, True)]) + def test_model_id_and_spec(self, on_demand_init, just_in_time_init): + @specifiable( + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) + class DummyAggregation(AggregationFn): + def apply(self, x): + ... + + @specifiable( + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) + class DummyEnsemble(EnsembleAnomalyDetector): + def __init__(self, my_ensemble_arg=None, **kwargs): + super().__init__(**kwargs) + self._my_ensemble_arg = my_ensemble_arg + + def learn_one(self): + ... + + def score_one(self): + ... + + def __eq__(self, value) -> bool: + return isinstance(value, DummyEnsemble) and \ + self._my_ensemble_arg == value._my_ensemble_arg + + @specifiable( + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) + class DummyWeakLearner(AnomalyDetector): + def __init__(self, my_arg=None, **kwargs): + super().__init__(**kwargs) + self._my_arg = my_arg + + def learn_one(self): + ... + + def score_one(self): + ... + + def __eq__(self, value) -> bool: + return isinstance(value, DummyWeakLearner) \ + and self._my_arg == value._my_arg + + # See the comment in TestAnomalyDetector for more details. 
+ if just_in_time_init or not on_demand_init: + a = DummyEnsemble() + self.assertEqual(a._model_id, "DummyEnsemble") + + b = DummyEnsemble(model_id="my_dummy_ensemble") + self.assertEqual(b._model_id, "my_dummy_ensemble") + + c = EnsembleAnomalyDetector() + self.assertEqual(c._model_id, "custom") + + d = EnsembleAnomalyDetector(model_id="my_dummy_ensemble_2") + self.assertEqual(d._model_id, "my_dummy_ensemble_2") + + d1 = DummyWeakLearner(my_arg=1) + d2 = DummyWeakLearner(my_arg=2) + ensemble = DummyEnsemble( + my_ensemble_arg=123, + learners=[d1, d2], + aggregation_strategy=DummyAggregation()) + + expected_spec = Spec( + type="DummyEnsemble", + config={ + "my_ensemble_arg": 123, + "learners": [ + Spec(type="DummyWeakLearner", config={"my_arg": 1}), + Spec(type="DummyWeakLearner", config={"my_arg": 2}) + ], + "aggregation_strategy": Spec( + type="DummyAggregation", + config={}, + ), + }) + + assert isinstance(ensemble, Specifiable) + spec = ensemble.to_spec() + self.assertEqual(spec, expected_spec) + + ensemble_dup = Specifiable.from_spec(spec) + + # See the comment in TestAnomalyDetector for more details. + if on_demand_init and not just_in_time_init: + ensemble.run_original_init() + + self.assertEqual(ensemble, ensemble_dup) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py new file mode 100644 index 000000000000..1aedab2e8c21 --- /dev/null +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -0,0 +1,326 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A module that provides utilities to turn a class into a Specifiable subclass. +""" + +from __future__ import annotations + +import collections +import dataclasses +import inspect +import logging +from typing import Any +from typing import ClassVar +from typing import List +from typing import Protocol +from typing import Type +from typing import TypeVar +from typing import runtime_checkable + +from typing_extensions import Self + +__all__ = ["Spec", "Specifiable", "specifiable"] + +_FALLBACK_SUBSPACE = "*" + +_ACCEPTED_SUBSPACES = [ + "EnsembleAnomalyDetector", + "AnomalyDetector", + "ThresholdFn", + "AggregationFn", + _FALLBACK_SUBSPACE, +] + +#: A nested dictionary for efficient lookup of Specifiable subclasses. +#: Structure: `_KNOWN_SPECIFIABLE[subspace][spec_type]`, where `subspace` is one +#: of the accepted subspaces that the class belongs to and `spec_type` is the +#: class name by default. Users can also specify a different value for +#: `spec_type` when applying the `specifiable` decorator to an existing class. 
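#: For example (hypothetical entries), after registering a detector and a
#: threshold function the registry could look like
#: `{"AnomalyDetector": {"MyDetector": MyDetector},
#: "ThresholdFn": {"MyThreshold": MyThreshold}}`.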
+_KNOWN_SPECIFIABLE = collections.defaultdict(dict) + +SpecT = TypeVar('SpecT', bound='Specifiable') + + +def _class_to_subspace(cls: Type) -> str: + """ + Search the class hierarchy to find the subspace: the closest ancestor class in + the class's method resolution order (MRO) whose name is found in the accepted + subspace list. This is usually called when registering a new specifiable + class. + """ + for c in cls.mro(): + if c.__name__ in _ACCEPTED_SUBSPACES: + return c.__name__ + + return _FALLBACK_SUBSPACE + + +def _spec_type_to_subspace(spec_type: str) -> str: + """ + Look for the subspace for a spec type. This is usually called to retrieve + the subspace of a registered specifiable class. + """ + for subspace in _ACCEPTED_SUBSPACES: + if spec_type in _KNOWN_SPECIFIABLE[subspace]: + return subspace + + raise ValueError(f"subspace for {str} not found.") + + +@dataclasses.dataclass(frozen=True) +class Spec(): + """ + Dataclass for storing specifications of specifiable objects. + Objects can be initialized using the data in their corresponding spec. + """ + #: A string indicating the concrete `Specifiable` class + type: str + #: A dictionary of keyword arguments for the `__init__` method of the class. + config: dict[str, Any] = dataclasses.field(default_factory=dict) + + +@runtime_checkable +class Specifiable(Protocol): + """Protocol that a specifiable class needs to implement.""" + #: The value of the `type` field in the object's spec for this class. + spec_type: ClassVar[str] + #: The raw keyword arguments passed to `__init__` method during object + #: initialization. + init_kwargs: dict[str, Any] + + # a boolean to tell whether the original `__init__` method is called + _initialized: bool + # a boolean used by new_getattr to tell whether it is in the `__init__` method + # call + _in_init: bool + + @staticmethod + def _from_spec_helper(v, _run_init): + if isinstance(v, Spec): + return Specifiable.from_spec(v, _run_init) + + if isinstance(v, List): + return [Specifiable._from_spec_helper(e, _run_init) for e in v] + + return v + + @classmethod + def from_spec(cls, spec: Spec, _run_init: bool = True) -> Self: + """Generate a `Specifiable` subclass object based on a spec.""" + if spec.type is None: + raise ValueError(f"Spec type not found in {spec}") + + subspace = _spec_type_to_subspace(spec.type) + subclass: Type[Self] = _KNOWN_SPECIFIABLE[subspace].get(spec.type, None) + if subclass is None: + raise ValueError(f"Unknown spec type '{spec.type}' in {spec}") + + kwargs = { + k: Specifiable._from_spec_helper(v, _run_init) + for k, + v in spec.config.items() + } + + if _run_init: + kwargs["_run_init"] = True + return subclass(**kwargs) + + @staticmethod + def _to_spec_helper(v): + if isinstance(v, Specifiable): + return v.to_spec() + + if isinstance(v, List): + return [Specifiable._to_spec_helper(e) for e in v] + + return v + + def to_spec(self) -> Spec: + """Generate a spec from a `Specifiable` subclass object.""" + if getattr(type(self), 'spec_type', None) is None: + raise ValueError( + f"'{type(self).__name__}' not registered as Specifiable. 
" + f"Decorate ({type(self).__name__}) with @specifiable") + + args = {k: self._to_spec_helper(v) for k, v in self.init_kwargs.items()} + + return Spec(type=self.__class__.spec_type, config=args) + + def run_original_init(self) -> None: + """Invoke the original __init__ method with original keyword arguments""" + pass + + +# Register a `Specifiable` subclass in `KNOWN_SPECIFIABLE` +def _register(cls, spec_type=None) -> None: + if spec_type is None: + # By default, spec type is the class name. Users can override this with + # other unique identifier. + spec_type = cls.__name__ + + subspace = _class_to_subspace(cls) + if spec_type in _KNOWN_SPECIFIABLE[subspace]: + raise ValueError( + f"{spec_type} is already registered for " + f"specifiable class {_KNOWN_SPECIFIABLE[subspace][spec_type]}. " + "Please specify a different spec_type by @specifiable(spec_type=...).") + else: + _KNOWN_SPECIFIABLE[subspace][spec_type] = cls + + cls.spec_type = spec_type + + +# Keep a copy of arguments that are used to call the `__init__` method when the +# object is initialized. +def _get_init_kwargs(inst, init_method, *args, **kwargs): + params = dict( + zip(inspect.signature(init_method).parameters.keys(), (None, ) + args)) + del params['self'] + params.update(**kwargs) + return params + + +def specifiable( + my_cls=None, + /, + *, + spec_type=None, + on_demand_init=True, + just_in_time_init=True): + """A decorator that turns a class into a `Specifiable` subclass by + implementing the `Specifiable` protocol. + + To use the decorator, simply place `@specifiable` before the class + definition:: + + @specifiable + class Foo(): + ... + + For finer control, the decorator can accept arguments:: + + @specifiable(spec_type="My Class", on_demand_init=False) + class Bar(): + ... + + Args: + spec_type: The value of the `type` field in the Spec of a `Specifiable` + subclass. If not provided, the class name is used. This argument is useful + when registering multiple classes with the same base name; in such cases, + one can specify `spec_type` to different values to resolve conflict. + on_demand_init: If True, allow on-demand object initialization. The original + `__init__` method will be called when `_run_init=True` is passed to the + object's initialization function. + just_in_time_init: If True, allow just-in-time object initialization. The + original `__init__` method will be called when the first time an attribute + is accessed. + """ + def _wrapper(cls): + def new_init(self: Specifiable, *args, **kwargs): + self._initialized = False + self._in_init = False + + run_init_request = False + if "_run_init" in kwargs: + run_init_request = kwargs["_run_init"] + del kwargs["_run_init"] + + if 'init_kwargs' not in self.__dict__: + # If it is a child specifiable (i.e.g init_kwargs not set), we determine + # whether to skip the original __init__ call based on options: + # on_demand_init, just_in_time_init and _run_init. + # Otherwise (i.e. init_kwargs is set), we always call the original + # __init__ method for ancestor specifiable. 
+ self.init_kwargs = _get_init_kwargs( + self, original_init, *args, **kwargs) + logging.debug("Record init params in %s.new_init", class_name) + + if (on_demand_init and not run_init_request) or \ + (not on_demand_init and just_in_time_init): + logging.debug("Skip original %s.__init__", class_name) + return + + logging.debug("Call original %s.__init__ in new_init", class_name) + + original_init(self, *args, **kwargs) + self._initialized = True + + def run_original_init(self): + self._in_init = True + original_init(self, **self.init_kwargs) + self._in_init = False + self._initialized = True + + # __getattr__ is only called when an attribute is not found in the object + def new_getattr(self, name): + logging.debug( + "Trying to access %s.%s, but it is not found.", class_name, name) + + # Fix the infinite loop issue when pickling a Specifiable + if name in ["_in_init", "__getstate__"] and name not in self.__dict__: + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'") + + # If the attribute is not found during or after initialization, then + # it is a missing attribute. + if self._in_init or self._initialized: + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'") + + # Here, we know the object is not initialized, then we will call original + # init method. + logging.debug("Call original %s.__init__ in new_getattr", class_name) + run_original_init(self) + + # __getattribute__ is call for every attribute regardless whether it is + # present in the object. In this case, we don't cause an infinite loop + # if the attribute does not exist. + logging.debug( + "Call original %s.__getattribute__(%s) in new_getattr", + class_name, + name) + return self.__getattribute__(name) + + # start of the function body of _wrapper + _register(cls, spec_type) + + class_name = cls.__name__ + original_init = cls.__init__ + cls.__init__ = new_init + if just_in_time_init: + cls.__getattr__ = new_getattr + + cls.run_original_init = run_original_init + cls.to_spec = Specifiable.to_spec + cls._to_spec_helper = staticmethod(Specifiable._to_spec_helper) + cls.from_spec = classmethod(Specifiable.from_spec) + cls._from_spec_helper = staticmethod(Specifiable._from_spec_helper) + return cls + # end of the function body of _wrapper + + # When this decorator is called with arguments, i.e.. + # "@specifiable(arg1=...,arg2=...)", it is equivalent to assigning + # specifiable(arg1=..., arg2=...) to a variable, say decor_func, and then + # calling "@decor_func". + if my_cls is None: + return _wrapper + + # When this decorator is called without an argument, i.e. "@specifiable", + # we return the augmented class. + return _wrapper(my_cls) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py new file mode 100644 index 000000000000..19b9d81c3d53 --- /dev/null +++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py @@ -0,0 +1,471 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import dataclasses +import logging +import unittest +from typing import List +from typing import Optional + +from parameterized import parameterized + +from apache_beam.ml.anomaly.specifiable import _FALLBACK_SUBSPACE +from apache_beam.ml.anomaly.specifiable import _KNOWN_SPECIFIABLE +from apache_beam.ml.anomaly.specifiable import Spec +from apache_beam.ml.anomaly.specifiable import Specifiable +from apache_beam.ml.anomaly.specifiable import specifiable + + +class TestSpecifiable(unittest.TestCase): + def setUp(self) -> None: + # Remove all registered specifiable classes and reset. + _KNOWN_SPECIFIABLE.clear() + + def test_decorator_in_function_form(self): + class A(): + pass + + # class is not decorated and thus not registered + self.assertNotIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) + + # apply the decorator function to an existing class + A = specifiable(A) + self.assertEqual(A.spec_type, "A") + self.assertTrue(isinstance(A(), Specifiable)) + self.assertIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) + self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A"], A) + + # an error is raised if the specified spec_type already exists. + self.assertRaises(ValueError, specifiable, A) + + # apply the decorator function to an existing class with a different + # spec_type + A = specifiable(spec_type="A_DUP")(A) + self.assertEqual(A.spec_type, "A_DUP") + self.assertTrue(isinstance(A(), Specifiable)) + self.assertIn("A_DUP", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) + self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A_DUP"], A) + + # an error is raised if the specified spec_type already exists. 
+ self.assertRaises(ValueError, specifiable(spec_type="A_DUP"), A) + + def test_decorator_in_syntactic_sugar_form(self): + # call decorator without parameters + @specifiable + class B(): + pass + + self.assertTrue(isinstance(B(), Specifiable)) + self.assertIn("B", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) + self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["B"], B) + + # call decorator with parameters + @specifiable(spec_type="C_TYPE") + class C(): + pass + + self.assertTrue(isinstance(C(), Specifiable)) + self.assertIn("C_TYPE", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) + self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["C_TYPE"], C) + + def test_init_params_in_specifiable(self): + @specifiable + class ParentWithInitParams(): + def __init__(self, arg_1, arg_2=2, arg_3="3", **kwargs): + pass + + parent = ParentWithInitParams(10, arg_3="30", arg_4=40) + assert isinstance(parent, Specifiable) + self.assertEqual( + parent.init_kwargs, { + 'arg_1': 10, 'arg_3': '30', 'arg_4': 40 + }) + + # inheritance of a Specifiable subclass + @specifiable + class ChildWithInitParams(ParentWithInitParams): + def __init__(self, new_arg_1, new_arg_2=200, new_arg_3="300", **kwargs): + super().__init__(**kwargs) + + child = ChildWithInitParams( + 1000, arg_1=11, arg_2=20, new_arg_2=2000, arg_4=4000) + assert isinstance(child, Specifiable) + self.assertEqual( + child.init_kwargs, + { + 'new_arg_1': 1000, + 'arg_1': 11, + 'arg_2': 20, + 'new_arg_2': 2000, + 'arg_4': 4000 + }) + + # composite of Specifiable subclasses + @specifiable + class CompositeWithInitParams(): + def __init__( + self, + my_parent: Optional[ParentWithInitParams] = None, + my_child: Optional[ChildWithInitParams] = None): + pass + + composite = CompositeWithInitParams(parent, child) + assert isinstance(composite, Specifiable) + self.assertEqual( + composite.init_kwargs, { + 'my_parent': parent, 'my_child': child + }) + + def test_from_spec_on_unknown_spec_type(self): + self.assertRaises(ValueError, Specifiable.from_spec, Spec(type="unknown")) + + # To test from_spec and to_spec with/without just_in_time_init. 
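  # The nested Specifiable classes below (a Product inside each Entry, and a
  # list of Entry objects inside ShoppingCart) also exercise the recursive
  # handling in _to_spec_helper / _from_spec_helper.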
+ @parameterized.expand([(False, False), (True, False), (False, True), + (True, True)]) + def test_from_spec_and_to_spec(self, on_demand_init, just_in_time_init): + @specifiable( + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) + @dataclasses.dataclass + class Product(): + name: str + price: float + + @specifiable( + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) + class Entry(): + def __init__(self, product: Product, quantity: int = 1): + self._product = product + self._quantity = quantity + + def __eq__(self, value) -> bool: + return isinstance(value, Entry) and \ + self._product == value._product and \ + self._quantity == value._quantity + + @specifiable( + on_demand_init=on_demand_init, just_in_time_init=just_in_time_init) + @dataclasses.dataclass + class ShoppingCart(): + user_id: str + entries: List[Entry] + + orange = Product("orange", 1.0) + + expected_orange_spec = Spec( + "Product", config={ + 'name': 'orange', 'price': 1.0 + }) + assert isinstance(orange, Specifiable) + self.assertEqual(orange.to_spec(), expected_orange_spec) + + entry_1 = Entry(product=orange) + + expected_entry_spec_1 = Spec( + "Entry", config={ + 'product': expected_orange_spec, + }) + + assert isinstance(entry_1, Specifiable) + self.assertEqual(entry_1.to_spec(), expected_entry_spec_1) + + banana = Product("banana", 0.5) + expected_banana_spec = Spec( + "Product", config={ + 'name': 'banana', 'price': 0.5 + }) + entry_2 = Entry(product=banana, quantity=5) + expected_entry_spec_2 = Spec( + "Entry", config={ + 'product': expected_banana_spec, 'quantity': 5 + }) + + shopping_cart = ShoppingCart(user_id="test", entries=[entry_1, entry_2]) + expected_shopping_cart_spec = Spec( + "ShoppingCart", + config={ + "user_id": "test", + "entries": [expected_entry_spec_1, expected_entry_spec_2] + }) + + assert isinstance(shopping_cart, Specifiable) + self.assertEqual(shopping_cart.to_spec(), expected_shopping_cart_spec) + if on_demand_init and not just_in_time_init: + orange.run_original_init() + banana.run_original_init() + entry_1.run_original_init() + entry_2.run_original_init() + shopping_cart.run_original_init() + + self.assertEqual(Specifiable.from_spec(expected_orange_spec), orange) + self.assertEqual(Specifiable.from_spec(expected_entry_spec_1), entry_1) + self.assertEqual( + Specifiable.from_spec(expected_shopping_cart_spec), shopping_cart) + + +class TestInitCallCount(unittest.TestCase): + def test_on_demand_init(self): + @specifiable(on_demand_init=True, just_in_time_init=False) + class FooOnDemand(): + counter = 0 + + def __init__(self, arg): + self.my_arg = arg * 10 + FooOnDemand.counter += 1 # increment it when __init__ is called + + foo = FooOnDemand(123) + self.assertEqual(FooOnDemand.counter, 0) + self.assertIn("init_kwargs", foo.__dict__) + self.assertEqual(foo.__dict__["init_kwargs"], {"arg": 123}) + + self.assertNotIn("my_arg", foo.__dict__) + self.assertRaises(AttributeError, getattr, foo, "my_arg") + self.assertRaises(AttributeError, lambda: foo.my_arg) + self.assertRaises(AttributeError, getattr, foo, "unknown_arg") + self.assertRaises(AttributeError, lambda: foo.unknown_arg) + self.assertEqual(FooOnDemand.counter, 0) + + # __init__ is called when _run_init=True is used + foo_2 = FooOnDemand(456, _run_init=True) + self.assertEqual(FooOnDemand.counter, 1) + self.assertIn("init_kwargs", foo_2.__dict__) + self.assertEqual(foo_2.__dict__["init_kwargs"], {"arg": 456}) + + self.assertIn("my_arg", foo_2.__dict__) + self.assertEqual(foo_2.my_arg, 4560) + 
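    # Reading my_arg on the eagerly-initialized object does not run __init__
    # again, so the counter stays at 1.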
self.assertEqual(FooOnDemand.counter, 1) + + def test_just_in_time_init(self): + @specifiable(on_demand_init=False, just_in_time_init=True) + class FooJustInTime(): + counter = 0 + + def __init__(self, arg): + self.my_arg = arg * 10 + FooJustInTime.counter += 1 # increment it when __init__ is called + + foo = FooJustInTime(321) + self.assertEqual(FooJustInTime.counter, 0) + self.assertIn("init_kwargs", foo.__dict__) + self.assertEqual(foo.__dict__["init_kwargs"], {"arg": 321}) + + # __init__ hasn't been called yet + self.assertNotIn("my_arg", foo.__dict__) + self.assertEqual(FooJustInTime.counter, 0) + + # __init__ is called when trying to access a class attribute + self.assertEqual(foo.my_arg, 3210) + self.assertEqual(FooJustInTime.counter, 1) + self.assertRaises(AttributeError, lambda: foo.unknown_arg) + self.assertEqual(FooJustInTime.counter, 1) + + def test_on_demand_and_just_in_time_init(self): + @specifiable(on_demand_init=True, just_in_time_init=True) + class FooOnDemandAndJustInTime(): + counter = 0 + + def __init__(self, arg): + self.my_arg = arg * 10 + FooOnDemandAndJustInTime.counter += 1 + + foo = FooOnDemandAndJustInTime(987) + self.assertEqual(FooOnDemandAndJustInTime.counter, 0) + self.assertIn("init_kwargs", foo.__dict__) + self.assertEqual(foo.__dict__["init_kwargs"], {"arg": 987}) + self.assertNotIn("my_arg", foo.__dict__) + + self.assertEqual(FooOnDemandAndJustInTime.counter, 0) + # __init__ is called when trying to access a class attribute + self.assertEqual(foo.my_arg, 9870) + self.assertEqual(FooOnDemandAndJustInTime.counter, 1) + + # __init__ is called when _run_init=True is used + foo_2 = FooOnDemandAndJustInTime(789, _run_init=True) + self.assertEqual(FooOnDemandAndJustInTime.counter, 2) + self.assertIn("init_kwargs", foo_2.__dict__) + self.assertEqual(foo_2.__dict__["init_kwargs"], {"arg": 789}) + + self.assertEqual(FooOnDemandAndJustInTime.counter, 2) + # __init__ is NOT called after it is initialized + self.assertEqual(foo_2.my_arg, 7890) + self.assertEqual(FooOnDemandAndJustInTime.counter, 2) + + @specifiable(on_demand_init=True, just_in_time_init=True) + class FooForPickle(): + counter = 0 + + def __init__(self, arg): + self.my_arg = arg * 10 + type(self).counter += 1 + + def test_on_pickle(self): + FooForPickle = TestInitCallCount.FooForPickle + + import dill + FooForPickle.counter = 0 + foo = FooForPickle(456) + self.assertEqual(FooForPickle.counter, 0) + new_foo = dill.loads(dill.dumps(foo)) + self.assertEqual(FooForPickle.counter, 0) + self.assertEqual(new_foo.__dict__, foo.__dict__) + self.assertEqual(foo.my_arg, 4560) + self.assertEqual(FooForPickle.counter, 1) + new_foo_2 = dill.loads(dill.dumps(foo)) + self.assertEqual(FooForPickle.counter, 1) + self.assertEqual(new_foo_2.__dict__, foo.__dict__) + + # Note that pickle does not support classes/functions nested in a function. 
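    # That is why FooForPickle is a class attribute of this test case instead
    # of being defined inside this method like the other Foo* classes above.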
+ import pickle + FooForPickle.counter = 0 + foo = FooForPickle(456) + self.assertEqual(FooForPickle.counter, 0) + new_foo = pickle.loads(pickle.dumps(foo)) + self.assertEqual(FooForPickle.counter, 0) + self.assertEqual(new_foo.__dict__, foo.__dict__) + self.assertEqual(foo.my_arg, 4560) + self.assertEqual(FooForPickle.counter, 1) + new_foo_2 = pickle.loads(pickle.dumps(foo)) + self.assertEqual(FooForPickle.counter, 1) + self.assertEqual(new_foo_2.__dict__, foo.__dict__) + + import cloudpickle + FooForPickle.counter = 0 + foo = FooForPickle(456) + self.assertEqual(FooForPickle.counter, 0) + new_foo = cloudpickle.loads(cloudpickle.dumps(foo)) + self.assertEqual(FooForPickle.counter, 0) + self.assertEqual(new_foo.__dict__, foo.__dict__) + self.assertEqual(foo.my_arg, 4560) + self.assertEqual(FooForPickle.counter, 1) + new_foo_2 = cloudpickle.loads(cloudpickle.dumps(foo)) + self.assertEqual(FooForPickle.counter, 1) + self.assertEqual(new_foo_2.__dict__, foo.__dict__) + + +@specifiable +class Parent(): + counter = 0 + parent_class_var = 1000 + + def __init__(self, p): + self.parent_inst_var = p * 10 + Parent.counter += 1 + + +@specifiable +class Child_1(Parent): + counter = 0 + child_class_var = 2001 + + def __init__(self, c): + super().__init__(c) + self.child_inst_var = c + 1 + Child_1.counter += 1 + + +@specifiable +class Child_2(Parent): + counter = 0 + child_class_var = 2001 + + def __init__(self, c): + self.child_inst_var = c + 1 + super().__init__(c) + Child_2.counter += 1 + + +@specifiable +class Child_Error_1(Parent): + counter = 0 + child_class_var = 2001 + + def __init__(self, c): + # read an instance var in child that doesn't exist + self.child_inst_var += 1 + super().__init__(c) + Child_2.counter += 1 + + +@specifiable +class Child_Error_2(Parent): + counter = 0 + child_class_var = 2001 + + def __init__(self, c): + # read an instance var in parent without calling parent's __init__. 
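    # parent_inst_var was never created, so the deferred-init __getattr__
    # raises AttributeError while _in_init is True (see test_error_in_child).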
+ self.parent_inst_var += 1 + Child_2.counter += 1 + + +class TestNestedSpecifiable(unittest.TestCase): + @parameterized.expand([[Child_1, 0], [Child_2, 0], [Child_1, 1], [Child_2, 1], + [Child_1, 2], [Child_2, 2]]) + def test_nested_specifiable(self, Child, mode): + Parent.counter = 0 + Child.counter = 0 + child = Child(5) + + self.assertEqual(Parent.counter, 0) + self.assertEqual(Child.counter, 0) + + # accessing class vars won't trigger __init__ + self.assertEqual(child.parent_class_var, 1000) + self.assertEqual(child.child_class_var, 2001) + self.assertEqual(Parent.counter, 0) + self.assertEqual(Child.counter, 0) + + # accessing instance var will trigger __init__ + if mode == 0: + self.assertEqual(child.parent_inst_var, 50) + elif mode == 1: + self.assertEqual(child.child_inst_var, 6) + else: + self.assertRaises(AttributeError, lambda: child.unknown_var) + + self.assertEqual(Parent.counter, 1) + self.assertEqual(Child.counter, 1) + + # after initialization, it won't trigger __init__ again + self.assertEqual(child.parent_inst_var, 50) + self.assertEqual(child.child_inst_var, 6) + self.assertRaises(AttributeError, lambda: child.unknown_var) + + self.assertEqual(Parent.counter, 1) + self.assertEqual(Child.counter, 1) + + def test_error_in_child(self): + Parent.counter = 0 + child_1 = Child_Error_1(5) + + self.assertEqual(child_1.child_class_var, 2001) + + # error during child initialization + self.assertRaises(AttributeError, lambda: child_1.child_inst_var) + self.assertEqual(Parent.counter, 0) + self.assertEqual(Child_1.counter, 0) + + child_2 = Child_Error_2(5) + self.assertEqual(child_2.child_class_var, 2001) + + # error during child initialization + self.assertRaises(AttributeError, lambda: child_2.parent_inst_var) + self.assertEqual(Parent.counter, 0) + self.assertEqual(Child_2.counter, 0) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main()
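Taken together, the new modules can be exercised roughly as shown below. This sketch is not part of the patch: `ZScoreDetector`, its `sensitivity` argument, and the sample rows are illustrative; only the `base` and `specifiable` modules added above are assumed.

import apache_beam as beam

from apache_beam.ml.anomaly.base import AnomalyDetector
from apache_beam.ml.anomaly.specifiable import Specifiable
from apache_beam.ml.anomaly.specifiable import specifiable


@specifiable
class ZScoreDetector(AnomalyDetector):
  """A toy detector that scores values by distance from a running mean."""
  def __init__(self, sensitivity: float = 3.0, **kwargs):
    super().__init__(**kwargs)
    self._sensitivity = sensitivity
    self._count = 0
    self._mean = 0.0

  def learn_one(self, x: beam.Row) -> None:
    # Incrementally update the running mean of the `value` field.
    self._count += 1
    self._mean += (x.value - self._mean) / self._count

  def score_one(self, x: beam.Row) -> float:
    # Larger distance from the running mean -> larger outlier score.
    return abs(x.value - self._mean) * self._sensitivity


detector = ZScoreDetector(sensitivity=2.5, target="value")

# The decorator records the init kwargs, so the detector can be converted to a
# declarative spec and rebuilt from it.
spec = detector.to_spec()
# roughly: Spec(type='ZScoreDetector',
#               config={'sensitivity': 2.5, 'target': 'value'})
clone = Specifiable.from_spec(spec)

detector.learn_one(beam.Row(value=1.0))
print(detector.score_one(beam.Row(value=4.0)))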