diff --git a/sdks/python/apache_beam/ml/anomaly/aggregations.py b/sdks/python/apache_beam/ml/anomaly/aggregations.py index 832f28316502..6d9f3797663b 100644 --- a/sdks/python/apache_beam/ml/anomaly/aggregations.py +++ b/sdks/python/apache_beam/ml/anomaly/aggregations.py @@ -23,6 +23,9 @@ from typing import Iterable from typing import Optional +from apache_beam.ml.anomaly.base import DEFAULT_MISSING_LABEL +from apache_beam.ml.anomaly.base import DEFAULT_NORMAL_LABEL +from apache_beam.ml.anomaly.base import DEFAULT_OUTLIER_LABEL from apache_beam.ml.anomaly.base import AggregationFn from apache_beam.ml.anomaly.base import AnomalyPrediction from apache_beam.ml.anomaly.specifiable import specifiable @@ -69,9 +72,13 @@ def __init__( agg_func: Callable[[Iterable[int]], int], agg_model_id: Optional[str] = None, include_source_predictions: bool = False, - missing_label: int = -2, + normal_label: int = DEFAULT_NORMAL_LABEL, + outlier_label: int = DEFAULT_OUTLIER_LABEL, + missing_label: int = DEFAULT_MISSING_LABEL, ): self._agg = agg_func + self._normal_label = normal_label + self._outlier_label = outlier_label self._missing_label = missing_label _AggModelIdMixin.__init__(self, agg_model_id) _SourcePredictionMixin.__init__(self, include_source_predictions) @@ -208,10 +215,8 @@ class MajorityVote(LabelAggregation): **kwargs: Additional keyword arguments to pass to the base `LabelAggregation` class. """ - def __init__(self, normal_label=0, outlier_label=1, tie_breaker=0, **kwargs): + def __init__(self, tie_breaker=DEFAULT_NORMAL_LABEL, **kwargs): self._tie_breaker = tie_breaker - self._normal_label = normal_label - self._outlier_label = outlier_label def inner(predictions: Iterable[int]) -> int: counters = collections.Counter(predictions) @@ -248,10 +253,7 @@ class AllVote(LabelAggregation): **kwargs: Additional keyword arguments to pass to the base `LabelAggregation` class. """ - def __init__(self, normal_label=0, outlier_label=1, **kwargs): - self._normal_label = normal_label - self._outlier_label = outlier_label - + def __init__(self, **kwargs): def inner(predictions: Iterable[int]) -> int: return self._outlier_label if all( map(lambda p: p == self._outlier_label, @@ -282,10 +284,7 @@ class AnyVote(LabelAggregation): **kwargs: Additional keyword arguments to pass to the base `LabelAggregation` class. 
""" - def __init__(self, normal_label=0, outlier_label=1, **kwargs): - self._normal_label = normal_label - self._outlier_label = outlier_label - + def __init__(self, **kwargs): def inner(predictions: Iterable[int]) -> int: return self._outlier_label if any( map(lambda p: p == self._outlier_label, diff --git a/sdks/python/apache_beam/ml/anomaly/base.py b/sdks/python/apache_beam/ml/anomaly/base.py index 5886c7278119..e3c6252474bd 100644 --- a/sdks/python/apache_beam/ml/anomaly/base.py +++ b/sdks/python/apache_beam/ml/anomaly/base.py @@ -37,6 +37,10 @@ "EnsembleAnomalyDetector" ] +DEFAULT_NORMAL_LABEL = 0 +DEFAULT_OUTLIER_LABEL = 1 +DEFAULT_MISSING_LABEL = -2 + @dataclass(frozen=True) class AnomalyPrediction(): @@ -79,9 +83,9 @@ class ThresholdFn(abc.ABC): """ def __init__( self, - normal_label: int = 0, - outlier_label: int = 1, - missing_label: int = -2): + normal_label: int = DEFAULT_NORMAL_LABEL, + outlier_label: int = DEFAULT_OUTLIER_LABEL, + missing_label: int = DEFAULT_MISSING_LABEL): self._normal_label = normal_label self._outlier_label = outlier_label self._missing_label = missing_label diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index e0122d41d9d5..e73ef5513b64 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -123,7 +123,15 @@ def _from_spec_helper(v, _run_init): @classmethod def from_spec(cls, spec: Spec, _run_init: bool = True) -> Self: - """Generate a `Specifiable` subclass object based on a spec.""" + """Generate a `Specifiable` subclass object based on a spec. + + Args: + spec: the specification of a `Specifiable` subclass object + _run_init: whether to call `__init__` or not for the initial instantiation + + Returns: + Self: the `Specifiable` subclass object + """ if spec.type is None: raise ValueError(f"Spec type not found in {spec}") @@ -153,7 +161,11 @@ def _to_spec_helper(v): return v def to_spec(self) -> Spec: - """Generate a spec from a `Specifiable` subclass object.""" + """Generate a spec from a `Specifiable` subclass object. + + Returns: + Spec: The specification of the instance. + """ if getattr(type(self), 'spec_type', None) is None: raise ValueError( f"'{type(self).__name__}' not registered as Specifiable. " @@ -262,7 +274,14 @@ def new_init(self: Specifiable, *args, **kwargs): original_init(self, *args, **kwargs) self._initialized = True - def run_original_init(self): + def run_original_init(self) -> None: + """Execute the original `__init__` method with its saved arguments. + + For instances of the `Specifiable` class, initialization is deferred + (lazy initialization). This function forces the execution of the + original `__init__` method using the arguments captured during + the object's initial instantiation. 
+ """ self._in_init = True original_init(self, **self.init_kwargs) self._in_init = False diff --git a/sdks/python/apache_beam/ml/anomaly/thresholds.py b/sdks/python/apache_beam/ml/anomaly/thresholds.py index d777aa5cde00..8226a27b7c4e 100644 --- a/sdks/python/apache_beam/ml/anomaly/thresholds.py +++ b/sdks/python/apache_beam/ml/anomaly/thresholds.py @@ -17,172 +17,13 @@ from __future__ import annotations -import dataclasses import math -from typing import Any -from typing import Iterable from typing import Optional -from typing import Tuple -from typing import Union -import apache_beam as beam -from apache_beam.coders import DillCoder -from apache_beam.ml.anomaly.base import AnomalyResult from apache_beam.ml.anomaly.base import ThresholdFn -from apache_beam.ml.anomaly.specifiable import Spec -from apache_beam.ml.anomaly.specifiable import Specifiable from apache_beam.ml.anomaly.specifiable import specifiable from apache_beam.ml.anomaly.univariate.quantile import BufferedSlidingQuantileTracker # pylint: disable=line-too-long from apache_beam.ml.anomaly.univariate.quantile import QuantileTracker -from apache_beam.transforms.userstate import ReadModifyWriteRuntimeState -from apache_beam.transforms.userstate import ReadModifyWriteStateSpec - - -class BaseThresholdDoFn(beam.DoFn): - """Applies a ThresholdFn to anomaly detection results. - - This abstract base class defines the structure for DoFns that use a - `ThresholdFn` to convert anomaly scores into anomaly labels (e.g., normal - or outlier). It handles the core logic of applying the threshold function - and updating the prediction labels within `AnomalyResult` objects. - - Args: - threshold_fn_spec (Spec): Specification defining the `ThresholdFn` to be - used. - """ - def __init__(self, threshold_fn_spec: Spec): - self._threshold_fn_spec = threshold_fn_spec - self._threshold_fn = None - - def _apply_threshold_to_predictions( - self, result: AnomalyResult) -> AnomalyResult: - """Updates the prediction labels in an AnomalyResult using the ThresholdFn. - - Args: - result (AnomalyResult): The input `AnomalyResult` containing anomaly - scores. - - Returns: - AnomalyResult: A new `AnomalyResult` with updated prediction labels - and threshold values. - """ - predictions = [ - dataclasses.replace( - p, - label=self._threshold_fn.apply(p.score), - threshold=self._threshold_fn.threshold) for p in result.predictions - ] - return dataclasses.replace(result, predictions=predictions) - - -class StatelessThresholdDoFn(BaseThresholdDoFn): - """Applies a stateless ThresholdFn to anomaly detection results. - - This DoFn is designed for stateless `ThresholdFn` implementations. It - initializes the `ThresholdFn` once during setup and applies it to each - incoming element without maintaining any state across elements. - - Args: - threshold_fn_spec (Spec): Specification defining the `ThresholdFn` to be - used. - - Raises: - AssertionError: If the provided `threshold_fn_spec` leads to the - creation of a stateful `ThresholdFn`. - """ - def __init__(self, threshold_fn_spec: Spec): - threshold_fn_spec.config["_run_init"] = True - self._threshold_fn: Any = Specifiable.from_spec(threshold_fn_spec) - assert not self._threshold_fn.is_stateful, \ - "This DoFn can only take stateless function as threshold_fn" - - def process(self, element: Tuple[Any, Tuple[Any, AnomalyResult]], - **kwargs) -> Iterable[Tuple[Any, Tuple[Any, AnomalyResult]]]: - """Processes a batch of anomaly results using a stateless ThresholdFn. 
- - Args: - element (Tuple[Any, Tuple[Any, AnomalyResult]]): A tuple representing - an element in the Beam pipeline. It is expected to be in the format - `(key1, (key2, AnomalyResult))`, where key1 is the original input key, - and key2 is a disambiguating key for distinct data points. - **kwargs: Additional keyword arguments passed to the `process` method - in Beam DoFns. - - Yields: - Iterable[Tuple[Any, Tuple[Any, AnomalyResult]]]: An iterable containing - a single output element with the same structure as the input, but with - the `AnomalyResult` having updated prediction labels based on the - stateless `ThresholdFn`. - """ - k1, (k2, result) = element - yield k1, (k2, self._apply_threshold_to_predictions(result)) - - -class StatefulThresholdDoFn(BaseThresholdDoFn): - """Applies a stateful ThresholdFn to anomaly detection results. - - This DoFn is designed for stateful `ThresholdFn` implementations. It leverages - Beam's state management to persist and update the state of the `ThresholdFn` - across multiple elements. This is necessary for `ThresholdFn`s that need to - accumulate information or adapt over time, such as quantile-based thresholds. - - Args: - threshold_fn_spec (Spec): Specification defining the `ThresholdFn` to be - used. - - Raises: - AssertionError: If the provided `threshold_fn_spec` leads to the - creation of a stateless `ThresholdFn`. - """ - THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec('saved_tracker', DillCoder()) - - def __init__(self, threshold_fn_spec: Spec): - threshold_fn_spec.config["_run_init"] = True - threshold_fn: Any = Specifiable.from_spec(threshold_fn_spec) - assert threshold_fn.is_stateful, \ - "This DoFn can only take stateful function as threshold_fn" - self._threshold_fn_spec = threshold_fn_spec - - def process( - self, - element: Tuple[Any, Tuple[Any, AnomalyResult]], - threshold_state: Union[ReadModifyWriteRuntimeState, - Any] = beam.DoFn.StateParam(THRESHOLD_STATE_INDEX), - **kwargs) -> Iterable[Tuple[Any, Tuple[Any, AnomalyResult]]]: - """Processes a batch of anomaly results using a stateful ThresholdFn. - - For each input element, this DoFn retrieves the stateful `ThresholdFn` from - Beam state, initializes it if it's the first time, applies it to update - the prediction labels in the `AnomalyResult`, and then updates the state in - Beam for future elements. - - Args: - element (Tuple[Any, Tuple[Any, AnomalyResult]]): A tuple representing - an element in the Beam pipeline. It is expected to be in the format - `(key1, (key2, AnomalyResult))`, where key1 is the original input key, - and key2 is a disambiguating key for distinct data points. - threshold_state (Union[ReadModifyWriteRuntimeState, Any]): A Beam state - parameter that provides access to the persisted state of the - `ThresholdFn`. It is automatically managed by Beam. - **kwargs: Additional keyword arguments passed to the `process` method - in Beam DoFns. - - Yields: - Iterable[Tuple[Any, Tuple[Any, AnomalyResult]]]: An iterable containing - a single output element with the same structure as the input, but - with the `AnomalyResult` having updated prediction labels based on - the stateful `ThresholdFn`. 
- """ - k1, (k2, result) = element - - self._threshold_fn = threshold_state.read() - if self._threshold_fn is None: - self._threshold_fn: Specifiable = Specifiable.from_spec( - self._threshold_fn_spec) - - yield k1, (k2, self._apply_threshold_to_predictions(result)) - - threshold_state.write(self._threshold_fn) @specifiable diff --git a/sdks/python/apache_beam/ml/anomaly/thresholds_test.py b/sdks/python/apache_beam/ml/anomaly/thresholds_test.py index 413c0e52c6b0..bd2629ee00c1 100644 --- a/sdks/python/apache_beam/ml/anomaly/thresholds_test.py +++ b/sdks/python/apache_beam/ml/anomaly/thresholds_test.py @@ -18,18 +18,10 @@ import logging import unittest -import apache_beam as beam from apache_beam.ml.anomaly import thresholds -from apache_beam.ml.anomaly.base import AnomalyPrediction -from apache_beam.ml.anomaly.base import AnomalyResult from apache_beam.ml.anomaly.specifiable import Spec from apache_beam.ml.anomaly.univariate.quantile import BufferedSlidingQuantileTracker # pylint: disable=line-too-long from apache_beam.ml.anomaly.univariate.quantile import SimpleSlidingQuantileTracker # pylint: disable=line-too-long -from apache_beam.testing.test_pipeline import TestPipeline -from apache_beam.testing.util import assert_that -from apache_beam.testing.util import equal_to - -R = beam.Row(x=10, y=20) class TestFixedThreshold(unittest.TestCase): @@ -40,93 +32,6 @@ def test_apply_only(self): self.assertEqual(threshold_fn.apply(None), None) self.assertEqual(threshold_fn.apply(float('NaN')), -2) - def test_dofn_on_single_prediction(self): - input = [ - (1, (2, AnomalyResult(R, [AnomalyPrediction(score=1)]))), - (1, (3, AnomalyResult(R, [AnomalyPrediction(score=2)]))), - (1, (4, AnomalyResult(R, [AnomalyPrediction(score=3)]))), - ] - expected = [ - ( - 1, - ( - 2, - AnomalyResult( - R, [AnomalyPrediction(score=1, label=0, threshold=2)]))), - ( - 1, - ( - 3, - AnomalyResult( - R, [AnomalyPrediction(score=2, label=1, threshold=2)]))), - ( - 1, - ( - 4, - AnomalyResult( - R, [AnomalyPrediction(score=3, label=1, threshold=2)]))), - ] - with TestPipeline() as p: - result = ( - p - | beam.Create(input) - | beam.ParDo( - thresholds.StatelessThresholdDoFn( - thresholds.FixedThreshold(2, normal_label=0, - outlier_label=1).to_spec()))) - assert_that(result, equal_to(expected)) - - def test_dofn_on_multiple_predictions(self): - input = [ - ( - 1, - ( - 2, - AnomalyResult( - R, - [AnomalyPrediction(score=1), AnomalyPrediction(score=4)]))), - ( - 1, - ( - 3, - AnomalyResult( - R, - [AnomalyPrediction(score=2), AnomalyPrediction(score=0.5) - ]))), - ] - expected = [ - ( - 1, - ( - 2, - AnomalyResult( - R, - [ - AnomalyPrediction(score=1, label=0, threshold=2), - AnomalyPrediction(score=4, label=1, threshold=2) - ]))), - ( - 1, - ( - 3, - AnomalyResult( - R, - [ - AnomalyPrediction(score=2, label=1, threshold=2), - AnomalyPrediction(score=0.5, label=0, threshold=2) - ]))), - ] - with TestPipeline() as p: - result = ( - p - | beam.Create(input) - | beam.ParDo( - thresholds.StatelessThresholdDoFn( - thresholds.FixedThreshold(2, normal_label=0, - outlier_label=1).to_spec()))) - - assert_that(result, equal_to(expected)) - class TestQuantileThreshold(unittest.TestCase): def test_apply_only(self): @@ -137,67 +42,6 @@ def test_apply_only(self): self.assertEqual(threshold_fn.apply(None), None) self.assertEqual(threshold_fn.apply(float('NaN')), -2) - def test_dofn_on_single_prediction(self): - # use the input data with two keys to test stateful threshold function - input = [ - (1, (2, AnomalyResult(R, 
[AnomalyPrediction(score=1)]))), - (1, (3, AnomalyResult(R, [AnomalyPrediction(score=2)]))), - (1, (4, AnomalyResult(R, [AnomalyPrediction(score=3)]))), - (2, (2, AnomalyResult(R, [AnomalyPrediction(score=10)]))), - (2, (3, AnomalyResult(R, [AnomalyPrediction(score=20)]))), - (2, (4, AnomalyResult(R, [AnomalyPrediction(score=30)]))), - ] - expected = [ - ( - 1, - ( - 2, - AnomalyResult( - R, [AnomalyPrediction(score=1, label=1, threshold=1)]))), - ( - 1, - ( - 3, - AnomalyResult( - R, [AnomalyPrediction(score=2, label=1, threshold=1.5)]))), - ( - 2, - ( - 2, - AnomalyResult( - R, [AnomalyPrediction(score=10, label=1, threshold=10)]))), - ( - 2, - ( - 3, - AnomalyResult( - R, [AnomalyPrediction(score=20, label=1, threshold=15)]))), - ( - 1, - ( - 4, - AnomalyResult( - R, [AnomalyPrediction(score=3, label=1, threshold=2)]))), - ( - 2, - ( - 4, - AnomalyResult( - R, [AnomalyPrediction(score=30, label=1, threshold=20)]))), - ] - with TestPipeline() as p: - result = ( - p - | beam.Create(input) - # use median just for test convenience - | beam.ParDo( - thresholds.StatefulThresholdDoFn( - thresholds.QuantileThreshold( - quantile=0.5, normal_label=0, - outlier_label=1).to_spec()))) - - assert_that(result, equal_to(expected)) - def test_quantile_tracker(self): t1 = thresholds.QuantileThreshold() self.assertTrue(isinstance(t1._tracker, BufferedSlidingQuantileTracker)) @@ -215,7 +59,6 @@ def test_quantile_tracker(self): quantile=0.9, quantile_tracker=SimpleSlidingQuantileTracker(50, 0.975)) self.assertTrue(isinstance(t3._tracker, SimpleSlidingQuantileTracker)) self.assertEqual(t3._tracker._q, 0.975) - print(t3.to_spec()) self.assertEqual( t3.to_spec(), Spec( diff --git a/sdks/python/apache_beam/ml/anomaly/transforms.py b/sdks/python/apache_beam/ml/anomaly/transforms.py index 7053a16f5f7b..35cae18a7224 100644 --- a/sdks/python/apache_beam/ml/anomaly/transforms.py +++ b/sdks/python/apache_beam/ml/anomaly/transforms.py @@ -15,6 +15,7 @@ # limitations under the License. # +import dataclasses import uuid from typing import Callable from typing import Iterable @@ -30,11 +31,9 @@ from apache_beam.ml.anomaly.base import AnomalyPrediction from apache_beam.ml.anomaly.base import AnomalyResult from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector +from apache_beam.ml.anomaly.base import ThresholdFn from apache_beam.ml.anomaly.specifiable import Spec from apache_beam.ml.anomaly.specifiable import Specifiable -from apache_beam.ml.anomaly.thresholds import StatefulThresholdDoFn -from apache_beam.ml.anomaly.thresholds import StatelessThresholdDoFn -from apache_beam.ml.anomaly.thresholds import ThresholdFn from apache_beam.transforms.userstate import ReadModifyWriteStateSpec KeyT = TypeVar('KeyT') @@ -121,6 +120,149 @@ def expand( return input | beam.ParDo(_ScoreAndLearnDoFn(self._detector.to_spec())) +class _BaseThresholdDoFn(beam.DoFn): + """Applies a ThresholdFn to anomaly detection results. + + This abstract base class defines the structure for DoFns that use a + `ThresholdFn` to convert anomaly scores into anomaly labels (e.g., normal + or outlier). It handles the core logic of applying the threshold function + and updating the prediction labels within `AnomalyResult` objects. + + Args: + threshold_fn_spec (Spec): Specification defining the `ThresholdFn` to be + used. 
+ """ + def __init__(self, threshold_fn_spec: Spec): + self._threshold_fn_spec = threshold_fn_spec + + def _apply_threshold_to_predictions( + self, result: AnomalyResult) -> AnomalyResult: + """Updates the prediction labels in an AnomalyResult using the ThresholdFn. + + Args: + result (AnomalyResult): The input `AnomalyResult` containing anomaly + scores. + + Returns: + AnomalyResult: A new `AnomalyResult` with updated prediction labels + and threshold values. + """ + predictions = [ + dataclasses.replace( + p, + label=self._threshold_fn.apply(p.score), + threshold=self._threshold_fn.threshold) for p in result.predictions + ] + return dataclasses.replace(result, predictions=predictions) + + +class _StatelessThresholdDoFn(_BaseThresholdDoFn): + """Applies a stateless ThresholdFn to anomaly detection results. + + This DoFn is designed for stateless `ThresholdFn` implementations. It + initializes the `ThresholdFn` once during setup and applies it to each + incoming element without maintaining any state across elements. + + Args: + threshold_fn_spec (Spec): Specification defining the `ThresholdFn` to be + used. + + Raises: + AssertionError: If the provided `threshold_fn_spec` leads to the + creation of a stateful `ThresholdFn`. + """ + def __init__(self, threshold_fn_spec: Spec): + threshold_fn_spec.config["_run_init"] = True + self._threshold_fn = Specifiable.from_spec(threshold_fn_spec) + assert not self._threshold_fn.is_stateful, \ + "This DoFn can only take stateless function as threshold_fn" + + def process(self, element: KeyedOutputT, **kwargs) -> Iterable[KeyedOutputT]: + """Processes a batch of anomaly results using a stateless ThresholdFn. + + Args: + element (Tuple[Any, Tuple[Any, AnomalyResult]]): A tuple representing + an element in the Beam pipeline. It is expected to be in the format + `(key1, (key2, AnomalyResult))`, where key1 is the original input key, + and key2 is a disambiguating key for distinct data points. + **kwargs: Additional keyword arguments passed to the `process` method + in Beam DoFns. + + Yields: + Iterable[Tuple[Any, Tuple[Any, AnomalyResult]]]: An iterable containing + a single output element with the same structure as the input, but with + the `AnomalyResult` having updated prediction labels based on the + stateless `ThresholdFn`. + """ + k1, (k2, result) = element + yield k1, (k2, self._apply_threshold_to_predictions(result)) + + +class _StatefulThresholdDoFn(_BaseThresholdDoFn): + """Applies a stateful ThresholdFn to anomaly detection results. + + This DoFn is designed for stateful `ThresholdFn` implementations. It leverages + Beam's state management to persist and update the state of the `ThresholdFn` + across multiple elements. This is necessary for `ThresholdFn`s that need to + accumulate information or adapt over time, such as quantile-based thresholds. + + Args: + threshold_fn_spec (Spec): Specification defining the `ThresholdFn` to be + used. + + Raises: + AssertionError: If the provided `threshold_fn_spec` leads to the + creation of a stateless `ThresholdFn`. 
+ """ + THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec('saved_tracker', DillCoder()) + + def __init__(self, threshold_fn_spec: Spec): + threshold_fn_spec.config["_run_init"] = True + threshold_fn = Specifiable.from_spec(threshold_fn_spec) + assert threshold_fn.is_stateful, \ + "This DoFn can only take stateful function as threshold_fn" + self._threshold_fn_spec = threshold_fn_spec + + def process( + self, + element: KeyedOutputT, + threshold_state=beam.DoFn.StateParam(THRESHOLD_STATE_INDEX), + **kwargs) -> Iterable[KeyedOutputT]: + """Processes a batch of anomaly results using a stateful ThresholdFn. + + For each input element, this DoFn retrieves the stateful `ThresholdFn` from + Beam state, initializes it if it's the first time, applies it to update + the prediction labels in the `AnomalyResult`, and then updates the state in + Beam for future elements. + + Args: + element (Tuple[Any, Tuple[Any, AnomalyResult]]): A tuple representing + an element in the Beam pipeline. It is expected to be in the format + `(key1, (key2, AnomalyResult))`, where key1 is the original input key, + and key2 is a disambiguating key for distinct data points. + threshold_state: A Beam state parameter that provides access to the + persisted state of the `ThresholdFn`. It is automatically managed by + Beam. + **kwargs: Additional keyword arguments passed to the `process` method + in Beam DoFns. + + Yields: + Iterable[Tuple[Any, Tuple[Any, AnomalyResult]]]: An iterable containing + a single output element with the same structure as the input, but + with the `AnomalyResult` having updated prediction labels based on + the stateful `ThresholdFn`. + """ + k1, (k2, result) = element + + self._threshold_fn = threshold_state.read() + if self._threshold_fn is None: + self._threshold_fn = Specifiable.from_spec(self._threshold_fn_spec) + + yield k1, (k2, self._apply_threshold_to_predictions(result)) + + threshold_state.write(self._threshold_fn) + + class RunThresholdCriterion(beam.PTransform[beam.PCollection[KeyedOutputT], beam.PCollection[KeyedOutputT]]): """Applies a threshold criterion to anomaly detection results. 
@@ -142,11 +284,11 @@ def expand( if self._threshold_fn.is_stateful: return ( input - | beam.ParDo(StatefulThresholdDoFn(self._threshold_fn.to_spec()))) + | beam.ParDo(_StatefulThresholdDoFn(self._threshold_fn.to_spec()))) else: return ( input - | beam.ParDo(StatelessThresholdDoFn(self._threshold_fn.to_spec()))) + | beam.ParDo(_StatelessThresholdDoFn(self._threshold_fn.to_spec()))) class RunAggregationStrategy(beam.PTransform[beam.PCollection[KeyedOutputT], diff --git a/sdks/python/apache_beam/ml/anomaly/transforms_test.py b/sdks/python/apache_beam/ml/anomaly/transforms_test.py index cdb869008dd5..cf398728f372 100644 --- a/sdks/python/apache_beam/ml/anomaly/transforms_test.py +++ b/sdks/python/apache_beam/ml/anomaly/transforms_test.py @@ -27,7 +27,10 @@ from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector from apache_beam.ml.anomaly.detectors.zscore import ZScore from apache_beam.ml.anomaly.thresholds import FixedThreshold +from apache_beam.ml.anomaly.thresholds import QuantileThreshold from apache_beam.ml.anomaly.transforms import AnomalyDetection +from apache_beam.ml.anomaly.transforms import _StatefulThresholdDoFn +from apache_beam.ml.anomaly.transforms import _StatelessThresholdDoFn from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to @@ -250,6 +253,161 @@ def test_multiple_sub_detectors_with_aggregation(self): prediction in zip(self._input, aggregated)])) +R = beam.Row(x=10, y=20) + + +class TestStatelessThresholdDoFn(unittest.TestCase): + def test_dofn_on_single_prediction(self): + input = [ + (1, (2, AnomalyResult(R, [AnomalyPrediction(score=1)]))), + (1, (3, AnomalyResult(R, [AnomalyPrediction(score=2)]))), + (1, (4, AnomalyResult(R, [AnomalyPrediction(score=3)]))), + ] + expected = [ + ( + 1, + ( + 2, + AnomalyResult( + R, [AnomalyPrediction(score=1, label=0, threshold=2)]))), + ( + 1, + ( + 3, + AnomalyResult( + R, [AnomalyPrediction(score=2, label=1, threshold=2)]))), + ( + 1, + ( + 4, + AnomalyResult( + R, [AnomalyPrediction(score=3, label=1, threshold=2)]))), + ] + with TestPipeline() as p: + result = ( + p + | beam.Create(input) + | beam.ParDo( + _StatelessThresholdDoFn( + FixedThreshold(2, normal_label=0, + outlier_label=1).to_spec()))) + assert_that(result, equal_to(expected)) + + def test_dofn_on_multiple_predictions(self): + input = [ + ( + 1, + ( + 2, + AnomalyResult( + R, + [AnomalyPrediction(score=1), AnomalyPrediction(score=4)]))), + ( + 1, + ( + 3, + AnomalyResult( + R, + [AnomalyPrediction(score=2), AnomalyPrediction(score=0.5) + ]))), + ] + expected = [ + ( + 1, + ( + 2, + AnomalyResult( + R, + [ + AnomalyPrediction(score=1, label=0, threshold=2), + AnomalyPrediction(score=4, label=1, threshold=2) + ]))), + ( + 1, + ( + 3, + AnomalyResult( + R, + [ + AnomalyPrediction(score=2, label=1, threshold=2), + AnomalyPrediction(score=0.5, label=0, threshold=2) + ]))), + ] + with TestPipeline() as p: + result = ( + p + | beam.Create(input) + | beam.ParDo( + _StatelessThresholdDoFn( + FixedThreshold(2, normal_label=0, + outlier_label=1).to_spec()))) + + assert_that(result, equal_to(expected)) + + +class TestStatefulThresholdDoFn(unittest.TestCase): + def test_dofn_on_single_prediction(self): + # use the input data with two keys to test stateful threshold function + input = [ + (1, (2, AnomalyResult(R, [AnomalyPrediction(score=1)]))), + (1, (3, AnomalyResult(R, [AnomalyPrediction(score=2)]))), + (1, (4, AnomalyResult(R, [AnomalyPrediction(score=3)]))), + (2, (2, 
AnomalyResult(R, [AnomalyPrediction(score=10)]))), + (2, (3, AnomalyResult(R, [AnomalyPrediction(score=20)]))), + (2, (4, AnomalyResult(R, [AnomalyPrediction(score=30)]))), + ] + expected = [ + ( + 1, + ( + 2, + AnomalyResult( + R, [AnomalyPrediction(score=1, label=1, threshold=1)]))), + ( + 1, + ( + 3, + AnomalyResult( + R, [AnomalyPrediction(score=2, label=1, threshold=1.5)]))), + ( + 2, + ( + 2, + AnomalyResult( + R, [AnomalyPrediction(score=10, label=1, threshold=10)]))), + ( + 2, + ( + 3, + AnomalyResult( + R, [AnomalyPrediction(score=20, label=1, threshold=15)]))), + ( + 1, + ( + 4, + AnomalyResult( + R, [AnomalyPrediction(score=3, label=1, threshold=2)]))), + ( + 2, + ( + 4, + AnomalyResult( + R, [AnomalyPrediction(score=30, label=1, threshold=20)]))), + ] + with TestPipeline() as p: + result = ( + p + | beam.Create(input) + # use median just for test convenience + | beam.ParDo( + _StatefulThresholdDoFn( + QuantileThreshold( + quantile=0.5, normal_label=0, + outlier_label=1).to_spec()))) + + assert_that(result, equal_to(expected)) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.WARNING) unittest.main()
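
A minimal sketch of how the relocated threshold logic is exercised after this change, mirroring the keyed-input pattern used in the moved tests (`(key1, (key2, AnomalyResult))` elements fed through `beam.ParDo`). The underscore prefix marks `_StatelessThresholdDoFn` and `_StatefulThresholdDoFn` as internal to `transforms.py`; application code would normally reach them through the public `AnomalyDetection` / `RunThresholdCriterion` transforms, so direct construction here is purely illustrative.

```python
# Sketch derived from the relocated tests in transforms_test.py; the element
# shape, FixedThreshold arguments, and DoFn names all come from this diff.
import apache_beam as beam
from apache_beam.ml.anomaly.base import AnomalyPrediction
from apache_beam.ml.anomaly.base import AnomalyResult
from apache_beam.ml.anomaly.thresholds import FixedThreshold
from apache_beam.ml.anomaly.transforms import _StatelessThresholdDoFn

row = beam.Row(x=10, y=20)
inputs = [
    # (key1, (key2, AnomalyResult)) -- key1 is the original input key,
    # key2 disambiguates distinct data points under that key.
    (1, (2, AnomalyResult(row, [AnomalyPrediction(score=1)]))),
    (1, (3, AnomalyResult(row, [AnomalyPrediction(score=3)]))),
]

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(inputs)
      # FixedThreshold is stateless, so the DoFn rebuilds it once from its
      # spec and applies it per element; scores >= 2 get outlier_label (1),
      # scores below 2 get normal_label (0), as in the moved test cases.
      | beam.ParDo(
          _StatelessThresholdDoFn(
              FixedThreshold(2, normal_label=0, outlier_label=1).to_spec()))
      | beam.Map(print))
```

For a stateful criterion such as `QuantileThreshold`, the same element shape goes through `_StatefulThresholdDoFn`, which persists the threshold function per key via `ReadModifyWriteStateSpec`, as shown in `TestStatefulThresholdDoFn` above.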