From d5aa691417febb7f25e591db26f18b31fca3405f Mon Sep 17 00:00:00 2001
From: Shunping Huang
Date: Mon, 10 Mar 2025 12:27:47 -0400
Subject: [PATCH 1/2] Add main and auxiliary transforms.

---
 .../apache_beam/ml/anomaly/transforms.py      | 381 ++++++++++++++++++
 .../apache_beam/ml/anomaly/transforms_test.py | 255 ++++++++++++
 2 files changed, 636 insertions(+)
 create mode 100644 sdks/python/apache_beam/ml/anomaly/transforms.py
 create mode 100644 sdks/python/apache_beam/ml/anomaly/transforms_test.py

diff --git a/sdks/python/apache_beam/ml/anomaly/transforms.py b/sdks/python/apache_beam/ml/anomaly/transforms.py
new file mode 100644
index 000000000000..136ad1c20e30
--- /dev/null
+++ b/sdks/python/apache_beam/ml/anomaly/transforms.py
@@ -0,0 +1,381 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import typing
+import uuid
+from typing import Any
+from typing import Callable
+from typing import Iterable
+from typing import Tuple
+from typing import TypeVar
+
+import apache_beam as beam
+from apache_beam.coders import DillCoder
+from apache_beam.ml.anomaly import aggregations
+from apache_beam.ml.anomaly.base import AggregationFn
+from apache_beam.ml.anomaly.base import AnomalyDetector
+from apache_beam.ml.anomaly.base import AnomalyPrediction
+from apache_beam.ml.anomaly.base import AnomalyResult
+from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector
+from apache_beam.ml.anomaly.specifiable import Spec
+from apache_beam.ml.anomaly.specifiable import Specifiable
+from apache_beam.ml.anomaly.thresholds import StatefulThresholdDoFn
+from apache_beam.ml.anomaly.thresholds import StatelessThresholdDoFn
+from apache_beam.transforms.userstate import ReadModifyWriteRuntimeState
+from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
+
+KeyT = TypeVar('KeyT')
+TempKeyT = TypeVar('TempKeyT', bound=str)
+InputT = Tuple[KeyT, beam.Row]
+KeyedInputT = Tuple[KeyT, Tuple[TempKeyT, beam.Row]]
+KeyedOutputT = Tuple[KeyT, Tuple[TempKeyT, AnomalyResult]]
+OutputT = Tuple[KeyT, AnomalyResult]
+
+
+class _ScoreAndLearnDoFn(beam.DoFn):
+  """Scores and learns from incoming data using an anomaly detection model.
+
+  This DoFn applies an anomaly detection model to score incoming data and
+  then updates the model with the same data. It maintains the model state
+  using Beam's state management.
+  """
+  MODEL_STATE_INDEX = ReadModifyWriteStateSpec('saved_model', DillCoder())
+
+  def __init__(self, detector_spec: Spec):
+    self._detector_spec = detector_spec
+    self._detector_spec.config["_run_init"] = True
+
+  def score_and_learn(self, data):
+    """Scores and learns from a single data point.
+
+    Args:
+      data: A `beam.Row` representing the input data point.
+
+    Returns:
+      float: The anomaly score predicted by the model.
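+
+    Example (illustrative): for a detector with `_features = ["x1"]`, an
+    input row `beam.Row(x1=2, x2=10)` is reduced to `beam.Row(x1=2)` before
+    it is passed to `score_one` and `learn_one`.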
+ """ + assert self._underlying + if self._underlying._features is not None: + x = beam.Row(**{f: getattr(data, f) for f in self._underlying._features}) + else: + x = beam.Row(**data._asdict()) + + # score the incoming data using the existing model + y_pred = self._underlying.score_one(x) + + # then update the model with the same data + self._underlying.learn_one(x) + + return y_pred + + def process( + self, + element: KeyedInputT, + model_state=beam.DoFn.StateParam(MODEL_STATE_INDEX), + **kwargs) -> Iterable[KeyedOutputT]: + + model_state = typing.cast(ReadModifyWriteRuntimeState, model_state) + k1, (k2, data) = element + self._underlying: AnomalyDetector = model_state.read() + if self._underlying is None: + self._underlying = typing.cast( + AnomalyDetector, Specifiable.from_spec(self._detector_spec)) + + yield k1, (k2, + AnomalyResult( + example=data, + predictions=[AnomalyPrediction( + model_id=self._underlying._model_id, + score=self.score_and_learn(data))])) + + model_state.write(self._underlying) + + +class RunScoreAndLearn(beam.PTransform[beam.PCollection[KeyedInputT], + beam.PCollection[KeyedOutputT]]): + """Applies the _ScoreAndLearnDoFn to a PCollection of data. + + This PTransform scores and learns from data points using an anomaly + detection model. + + Args: + detector: The anomaly detection model to use. + """ + def __init__(self, detector: AnomalyDetector): + self._detector = detector + + def expand( + self, + input: beam.PCollection[KeyedInputT]) -> beam.PCollection[KeyedOutputT]: + return input | beam.ParDo(_ScoreAndLearnDoFn(self._detector.to_spec())) + + +class RunThresholdCriterion(beam.PTransform[beam.PCollection[KeyedOutputT], + beam.PCollection[KeyedOutputT]]): + """Applies a threshold criterion to anomaly detection results. + + This PTransform applies a `ThresholdFn` to the anomaly scores in + `AnomalyResult` objects, updating the prediction labels. It handles both + stateful and stateless `ThresholdFn` implementations. + + Args: + threshold_criterion: The `ThresholdFn` to apply. + """ + def __init__(self, threshold_criterion): + self._threshold_fn = threshold_criterion + + def expand( + self, + input: beam.PCollection[KeyedOutputT]) -> beam.PCollection[KeyedOutputT]: + if self._threshold_fn: + if self._threshold_fn.is_stateful: + ret = ( + input + | beam.ParDo(StatefulThresholdDoFn(self._threshold_fn.to_spec()))) + else: + ret = ( + input + | beam.ParDo(StatelessThresholdDoFn(self._threshold_fn.to_spec()))) + else: + ret = input + + return ret + + +class RunAggregationStrategy(beam.PTransform[beam.PCollection[KeyedOutputT], + beam.PCollection[KeyedOutputT]]): + """Applies an aggregation strategy to grouped anomaly detection results. + + This PTransform aggregates anomaly predictions from multiple models or + data points using an `AggregationFn`. It handles both custom and simple + aggregation strategies. + + Args: + aggregation_strategy: The `AggregationFn` to use. + agg_model_id: The model ID for aggregation. 
+ """ + def __init__(self, aggregation_strategy, agg_model_id): + self._aggregation_fn = aggregation_strategy + self._agg_model_id = agg_model_id + + def expand( + self, + input: beam.PCollection[KeyedOutputT]) -> beam.PCollection[KeyedOutputT]: + post_gbk = ( + input | beam.MapTuple(lambda k, v: ((k, v[0]), v[1])) + | beam.GroupByKey()) + + if self._aggregation_fn is None: + # simply put predictions into an iterable (list) + ret: Any = ( + post_gbk | beam.MapTuple( + lambda k, + v: ( + k[0], + ( + k[1], + AnomalyResult( + example=v[0].example, + predictions=[ + prediction for result in v + for prediction in result.predictions + ]))))) + return ret + + # create a new aggregation_fn from spec and make sure it is initialized + aggregation_fn_spec = self._aggregation_fn.to_spec() + aggregation_fn_spec.config["_run_init"] = True + aggregation_fn: AggregationFn = typing.cast( + AggregationFn, Specifiable.from_spec(aggregation_fn_spec)) + + # if no _agg_model_id is set in the aggregation function, use + # model id from the ensemble instance + if (isinstance(aggregation_fn, aggregations._AggModelIdMixin)): + aggregation_fn._set_agg_model_id_if_unset(self._agg_model_id) + + ret = ( + post_gbk | beam.MapTuple( + lambda k, + v, + agg=aggregation_fn: ( + k[0], + ( + k[1], + AnomalyResult( + example=v[0].example, + predictions=[ + agg.apply([ + prediction for result in v + for prediction in result.predictions + ]) + ]))))) + return ret + + +class RunOneDetector(beam.PTransform[beam.PCollection[KeyedInputT], + beam.PCollection[KeyedOutputT]]): + """Runs a single anomaly detector on a PCollection of data. + + This PTransform applies a single `AnomalyDetector` to the input data, + including scoring, learning, and thresholding. + + Args: + detector: The `AnomalyDetector` to run. + """ + def __init__(self, detector): + self._detector = detector + + def expand( + self, + input: beam.PCollection[KeyedInputT]) -> beam.PCollection[KeyedOutputT]: + model_id = getattr( + self._detector, + "_model_id", + getattr(self._detector, "_key", "unknown_model")) + model_uuid = f"{model_id}:{uuid.uuid4().hex[:6]}" + + ret: Any = ( + input + | beam.Reshuffle() + | f"Score and Learn ({model_uuid})" >> RunScoreAndLearn(self._detector) + | f"Run Threshold Criterion ({model_uuid})" >> RunThresholdCriterion( + self._detector._threshold_criterion)) + + return ret + + +class RunEnsembleDetector(beam.PTransform[beam.PCollection[KeyedInputT], + beam.PCollection[KeyedOutputT]]): + """Runs an ensemble of anomaly detectors on a PCollection of data. + + This PTransform applies an `EnsembleAnomalyDetector` to the input data, + running each sub-detector and aggregating the results. + + Args: + ensemble_detector: The `EnsembleAnomalyDetector` to run. 
+ """ + def __init__(self, ensemble_detector: EnsembleAnomalyDetector): + self._ensemble_detector = ensemble_detector + + def expand( + self, + input: beam.PCollection[KeyedInputT]) -> beam.PCollection[KeyedOutputT]: + model_uuid = f"{self._ensemble_detector._model_id}:{uuid.uuid4().hex[:6]}" + + assert self._ensemble_detector._sub_detectors is not None + if not self._ensemble_detector._sub_detectors: + raise ValueError(f"No detectors found at {model_uuid}") + + results = [] + for idx, detector in enumerate(self._ensemble_detector._sub_detectors): + if isinstance(detector, EnsembleAnomalyDetector): + results.append( + input | f"Run Ensemble Detector at index {idx} ({model_uuid})" >> + RunEnsembleDetector(detector)) + else: + results.append( + input + | f"Run One Detector at index {idx} ({model_uuid})" >> + RunOneDetector(detector)) + + if self._ensemble_detector._aggregation_strategy is None: + aggregation_type = "Simple" + else: + aggregation_type = "Custom" + + aggregated = ( + results | beam.Flatten() + | f"Run {aggregation_type} Aggregation Strategy ({model_uuid})" >> + RunAggregationStrategy( + self._ensemble_detector._aggregation_strategy, + self._ensemble_detector._model_id)) + + ret: Any = ( + aggregated + | f"Run Threshold Criterion ({model_uuid})" >> RunThresholdCriterion( + self._ensemble_detector._threshold_criterion)) + + return ret + + +class AnomalyDetection(beam.PTransform[beam.PCollection[InputT], + beam.PCollection[OutputT]]): + """Performs anomaly detection on a PCollection of data. + + This PTransform applies an `AnomalyDetector` or `EnsembleAnomalyDetector` to + the input data and returns a PCollection of `AnomalyResult` objects. + + Examples:: + + # Run a single anomaly detector + p | AnomalyDetection(ZScore(features=["x1"])) + + # Run an ensemble anomaly detector + sub_detectors = [ZScore(features=["x1"]), IQR(features=["x2"])] + p | AnomalyDetection( + EnsembleAnomalyDetector(sub_detectors, aggregation_strategy=AnyVote())) + + Args: + detector: The `AnomalyDetector` or `EnsembleAnomalyDetector` to use. + """ + def __init__( + self, + detector: AnomalyDetector, + ) -> None: + self._root_detector = detector + + def expand( + self, + input: beam.PCollection[InputT], + ) -> beam.PCollection[OutputT]: + + # Add a temporary unique key per data point to facilitate grouping the + # outputs from multiple anomaly detectors for the same data point. + # + # Unique key generation options: + # (1) Timestamp-based methods: https://docs.python.org/3/library/time.html + # (2) UUID module: https://docs.python.org/3/library/uuid.html + # + # Timestamp precision on Windows can lead to key collisions (see PEP 564: + # https://peps.python.org/pep-0564/#windows). Only time.perf_counter_ns() + # provides sufficient precision for our needs. + # + # Performance note: + # $ python -m timeit -n 100000 "import uuid; uuid.uuid1()" + # 100000 loops, best of 5: 806 nsec per loop + # $ python -m timeit -n 100000 "import uuid; uuid.uuid4()" + # 100000 loops, best of 5: 1.53 usec per loop + # $ python -m timeit -n 100000 "import time; time.perf_counter_ns()" + # 100000 loops, best of 5: 82.3 nsec per loop + # + # We select uuid.uuid1() for its inclusion of node information, making it + # more suitable for parallel execution environments. 
+    add_temp_key_fn: Callable[[InputT], KeyedInputT] \
+        = lambda e: (e[0], (str(uuid.uuid1()), e[1]))
+    keyed_input = (input | "Add temp key" >> beam.Map(add_temp_key_fn))
+
+    if isinstance(self._root_detector, EnsembleAnomalyDetector):
+      keyed_output = (keyed_input | RunEnsembleDetector(self._root_detector))
+    else:
+      keyed_output = (keyed_input | RunOneDetector(self._root_detector))
+
+    # remove the temporary key and simplify the output.
+    remove_temp_key_fn: Callable[[KeyedOutputT], OutputT] \
+        = lambda e: (e[0], e[1][1])
+    ret: Any = keyed_output | "Remove temp key" >> beam.Map(remove_temp_key_fn)
+
+    return ret
diff --git a/sdks/python/apache_beam/ml/anomaly/transforms_test.py b/sdks/python/apache_beam/ml/anomaly/transforms_test.py
new file mode 100644
index 000000000000..cdb869008dd5
--- /dev/null
+++ b/sdks/python/apache_beam/ml/anomaly/transforms_test.py
@@ -0,0 +1,255 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import math
+import unittest
+from typing import Iterable
+
+import apache_beam as beam
+from apache_beam.ml.anomaly.aggregations import AnyVote
+from apache_beam.ml.anomaly.base import AnomalyPrediction
+from apache_beam.ml.anomaly.base import AnomalyResult
+from apache_beam.ml.anomaly.base import EnsembleAnomalyDetector
+from apache_beam.ml.anomaly.detectors.zscore import ZScore
+from apache_beam.ml.anomaly.thresholds import FixedThreshold
+from apache_beam.ml.anomaly.transforms import AnomalyDetection
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+
+def _prediction_iterable_is_equal_to(
+    a: Iterable[AnomalyPrediction], b: Iterable[AnomalyPrediction]):
+  a_list = list(a)
+  b_list = list(b)
+
+  if len(a_list) != len(b_list):
+    return False
+
+  return all(
+      map(lambda x: _prediction_is_equal_to(x[0], x[1]), zip(a_list, b_list)))
+
+
+def _prediction_is_equal_to(a: AnomalyPrediction, b: AnomalyPrediction):
+  if a.model_id != b.model_id:
+    return False
+
+  if a.threshold != b.threshold:
+    return False
+
+  if a.score != b.score:
+    if not (a.score is not None and b.score is not None and
+            math.isnan(a.score) and math.isnan(b.score)):
+      return False
+
+  if a.label != b.label:
+    return False
+
+  if a.info != b.info:
+    return False
+
+  if a.source_predictions is None and b.source_predictions is None:
+    return True
+
+  if a.source_predictions is not None and b.source_predictions is not None:
+    return _prediction_iterable_is_equal_to(
+        a.source_predictions, b.source_predictions)
+
+  return False
+
+
+def _keyed_result_is_equal_to(
+    a: tuple[int, AnomalyResult], b: tuple[int, AnomalyResult]):
+  if a[0] != b[0]:
+    return False
+
+  if a[1].example != b[1].example:
+    return False
+
+  return _prediction_iterable_is_equal_to(a[1].predictions, b[1].predictions)
+
+
+class TestAnomalyDetection(unittest.TestCase):
+  def setUp(self):
+    self._input = [
+        (1, beam.Row(x1=1, x2=4)),
+        (2, beam.Row(x1=100, x2=5)),  # a row with a different key (key=2)
+        (1, beam.Row(x1=2, x2=4)),
+        (1, beam.Row(x1=3, x2=5)),
+        (1, beam.Row(x1=10, x2=4)),  # outlier in key=1, with respect to x1
+        (1, beam.Row(x1=2, x2=10)),  # outlier in key=1, with respect to x2
+        (1, beam.Row(x1=3, x2=4)),
+    ]
+
+  def test_one_detector(self):
+    zscore_x1_expected = [
+        AnomalyPrediction(
+            model_id='zscore_x1', score=float('NaN'), label=-2, threshold=3),
+        AnomalyPrediction(
+            model_id='zscore_x1', score=float('NaN'), label=-2, threshold=3),
+        AnomalyPrediction(
+            model_id='zscore_x1', score=float('NaN'), label=-2, threshold=3),
+        AnomalyPrediction(
+            model_id='zscore_x1',
+            score=2.1213203435596424,
+            label=0,
+            threshold=3),
+        AnomalyPrediction(
+            model_id='zscore_x1', score=8.0, label=1, threshold=3),
+        AnomalyPrediction(
+            model_id='zscore_x1',
+            score=0.4898979485566356,
+            label=0,
+            threshold=3),
+        AnomalyPrediction(
+            model_id='zscore_x1',
+            score=0.16452254913212455,
+            label=0,
+            threshold=3),
+    ]
+    detector = ZScore(features=["x1"], model_id="zscore_x1")
+
+    with TestPipeline() as p:
+      result = (
+          p | beam.Create(self._input)
+          # TODO: get rid of this conversion from BeamSchema to beam.Row.
+          | beam.Map(lambda t: (t[0], beam.Row(**t[1]._asdict())))
+          | AnomalyDetection(detector))
+      assert_that(
+          result,
+          equal_to([(
+              input[0],
+              AnomalyResult(example=input[1], predictions=[decision]))
+                    for input,
+                    decision in zip(self._input, zscore_x1_expected)],
+                   _keyed_result_is_equal_to))
+
+  def test_multiple_detectors_without_aggregation(self):
+    zscore_x1_expected = [
+        AnomalyPrediction(
+            model_id='zscore_x1', score=float('NaN'), label=-2, threshold=3),
+        AnomalyPrediction(
+            model_id='zscore_x1', score=float('NaN'), label=-2, threshold=3),
+        AnomalyPrediction(
+            model_id='zscore_x1', score=float('NaN'), label=-2, threshold=3),
+        AnomalyPrediction(
+            model_id='zscore_x1',
+            score=2.1213203435596424,
+            label=0,
+            threshold=3),
+        AnomalyPrediction(
+            model_id='zscore_x1', score=8.0, label=1, threshold=3),
+        AnomalyPrediction(
+            model_id='zscore_x1',
+            score=0.4898979485566356,
+            label=0,
+            threshold=3),
+        AnomalyPrediction(
+            model_id='zscore_x1',
+            score=0.16452254913212455,
+            label=0,
+            threshold=3),
+    ]
+    zscore_x2_expected = [
+        AnomalyPrediction(
+            model_id='zscore_x2', score=float('NaN'), label=-2, threshold=2),
+        AnomalyPrediction(
+            model_id='zscore_x2', score=float('NaN'), label=-2, threshold=2),
+        AnomalyPrediction(
+            model_id='zscore_x2', score=float('NaN'), label=-2, threshold=2),
+        AnomalyPrediction(model_id='zscore_x2', score=0, label=0, threshold=2),
+        AnomalyPrediction(
+            model_id='zscore_x2',
+            score=0.5773502691896252,
+            label=0,
+            threshold=2),
+        AnomalyPrediction(
+            model_id='zscore_x2', score=11.5, label=1, threshold=2),
+        AnomalyPrediction(
+            model_id='zscore_x2',
+            score=0.5368754921931594,
+            label=0,
+            threshold=2),
+    ]
+
+    sub_detectors = []
+    sub_detectors.append(ZScore(features=["x1"], model_id="zscore_x1"))
+    sub_detectors.append(
+        ZScore(
+            features=["x2"],
+            threshold_criterion=FixedThreshold(2),
+            model_id="zscore_x2"))
+
+    with beam.Pipeline() as p:
+      result = (
+          p | beam.Create(self._input)
+          # TODO: get rid of this conversion from BeamSchema to beam.Row.
+          | beam.Map(lambda t: (t[0], beam.Row(**t[1]._asdict())))
+          | AnomalyDetection(EnsembleAnomalyDetector(sub_detectors)))
+
+      assert_that(
+          result,
+          equal_to([(
+              input[0],
+              AnomalyResult(
+                  example=input[1], predictions=[decision1, decision2]))
+                    for input,
+                    decision1,
+                    decision2 in zip(
+                        self._input, zscore_x1_expected, zscore_x2_expected)],
+                   _keyed_result_is_equal_to))
+
+  def test_multiple_sub_detectors_with_aggregation(self):
+    aggregated = [
+        AnomalyPrediction(model_id="custom", label=-2),
+        AnomalyPrediction(model_id="custom", label=-2),
+        AnomalyPrediction(model_id="custom", label=-2),
+        AnomalyPrediction(model_id="custom", label=0),
+        AnomalyPrediction(model_id="custom", label=1),
+        AnomalyPrediction(model_id="custom", label=1),
+        AnomalyPrediction(model_id="custom", label=0),
+    ]
+
+    sub_detectors = []
+    sub_detectors.append(ZScore(features=["x1"], model_id="zscore_x1"))
+    sub_detectors.append(
+        ZScore(
+            features=["x2"],
+            threshold_criterion=FixedThreshold(2),
+            model_id="zscore_x2"))
+
+    with beam.Pipeline() as p:
+      result = (
+          p | beam.Create(self._input)
+          # TODO: get rid of this conversion from BeamSchema to beam.Row.
+          | beam.Map(lambda t: (t[0], beam.Row(**t[1]._asdict())))
+          | AnomalyDetection(
+              EnsembleAnomalyDetector(
+                  sub_detectors, aggregation_strategy=AnyVote())))
+
+      assert_that(
+          result,
+          equal_to([(
+              input[0],
+              AnomalyResult(example=input[1], predictions=[prediction]))
+                    for input,
+                    prediction in zip(self._input, aggregated)]))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.WARNING)
+  unittest.main()

From 2ee03931494715465b464ad0a4afc65f3599d514 Mon Sep 17 00:00:00 2001
From: Shunping Huang
Date: Tue, 11 Mar 2025 13:23:51 -0400
Subject: [PATCH 2/2] Minor fix per reviewer's feedback and fix lints.

---
 .../apache_beam/ml/anomaly/transforms.py      | 65 +++++++++----------
 1 file changed, 32 insertions(+), 33 deletions(-)

diff --git a/sdks/python/apache_beam/ml/anomaly/transforms.py b/sdks/python/apache_beam/ml/anomaly/transforms.py
index 136ad1c20e30..7053a16f5f7b 100644
--- a/sdks/python/apache_beam/ml/anomaly/transforms.py
+++ b/sdks/python/apache_beam/ml/anomaly/transforms.py
@@ -15,11 +15,10 @@
 # limitations under the License.
 #
 
-import typing
 import uuid
-from typing import Any
 from typing import Callable
 from typing import Iterable
+from typing import Optional
 from typing import Tuple
 from typing import TypeVar
 
@@ -35,7 +34,7 @@ from apache_beam.ml.anomaly.specifiable import Specifiable
 from apache_beam.ml.anomaly.thresholds import StatefulThresholdDoFn
 from apache_beam.ml.anomaly.thresholds import StatelessThresholdDoFn
-from apache_beam.transforms.userstate import ReadModifyWriteRuntimeState
+from apache_beam.ml.anomaly.thresholds import ThresholdFn
 from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
 
 KeyT = TypeVar('KeyT')
@@ -88,12 +87,10 @@ def process(
       model_state=beam.DoFn.StateParam(MODEL_STATE_INDEX),
       **kwargs) -> Iterable[KeyedOutputT]:
 
-    model_state = typing.cast(ReadModifyWriteRuntimeState, model_state)
     k1, (k2, data) = element
     self._underlying: AnomalyDetector = model_state.read()
     if self._underlying is None:
-      self._underlying = typing.cast(
-          AnomalyDetector, Specifiable.from_spec(self._detector_spec))
+      self._underlying = Specifiable.from_spec(self._detector_spec)
 
     yield k1, (k2,
                AnomalyResult(
@@ -135,25 +132,21 @@ class RunThresholdCriterion(beam.PTransform[beam.PCollection[KeyedOutputT],
 
   Args:
     threshold_criterion: The `ThresholdFn` to apply.
""" - def __init__(self, threshold_criterion): + def __init__(self, threshold_criterion: ThresholdFn): self._threshold_fn = threshold_criterion def expand( self, input: beam.PCollection[KeyedOutputT]) -> beam.PCollection[KeyedOutputT]: - if self._threshold_fn: - if self._threshold_fn.is_stateful: - ret = ( - input - | beam.ParDo(StatefulThresholdDoFn(self._threshold_fn.to_spec()))) - else: - ret = ( - input - | beam.ParDo(StatelessThresholdDoFn(self._threshold_fn.to_spec()))) - else: - ret = input - return ret + if self._threshold_fn.is_stateful: + return ( + input + | beam.ParDo(StatefulThresholdDoFn(self._threshold_fn.to_spec()))) + else: + return ( + input + | beam.ParDo(StatelessThresholdDoFn(self._threshold_fn.to_spec()))) class RunAggregationStrategy(beam.PTransform[beam.PCollection[KeyedOutputT], @@ -168,7 +161,8 @@ class RunAggregationStrategy(beam.PTransform[beam.PCollection[KeyedOutputT], aggregation_strategy: The `AggregationFn` to use. agg_model_id: The model ID for aggregation. """ - def __init__(self, aggregation_strategy, agg_model_id): + def __init__( + self, aggregation_strategy: Optional[AggregationFn], agg_model_id: str): self._aggregation_fn = aggregation_strategy self._agg_model_id = agg_model_id @@ -181,7 +175,7 @@ def expand( if self._aggregation_fn is None: # simply put predictions into an iterable (list) - ret: Any = ( + ret = ( post_gbk | beam.MapTuple( lambda k, v: ( @@ -199,14 +193,16 @@ def expand( # create a new aggregation_fn from spec and make sure it is initialized aggregation_fn_spec = self._aggregation_fn.to_spec() aggregation_fn_spec.config["_run_init"] = True - aggregation_fn: AggregationFn = typing.cast( - AggregationFn, Specifiable.from_spec(aggregation_fn_spec)) + aggregation_fn = Specifiable.from_spec(aggregation_fn_spec) # if no _agg_model_id is set in the aggregation function, use # model id from the ensemble instance if (isinstance(aggregation_fn, aggregations._AggModelIdMixin)): aggregation_fn._set_agg_model_id_if_unset(self._agg_model_id) + # post_gbk is a PCollection of ((original_key, temp_key), AnomalyResult). + # We use (original_key, temp_key) as the key for GroupByKey() so that + # scores from multiple detectors per data point are grouped. 
     ret = (
         post_gbk | beam.MapTuple(
             lambda k,
             v,
@@ -248,12 +244,15 @@ def expand(
         getattr(self._detector, "_key", "unknown_model"))
     model_uuid = f"{model_id}:{uuid.uuid4().hex[:6]}"
 
-    ret: Any = (
+    ret = (
         input
         | beam.Reshuffle()
-        | f"Score and Learn ({model_uuid})" >> RunScoreAndLearn(self._detector)
-        | f"Run Threshold Criterion ({model_uuid})" >> RunThresholdCriterion(
-            self._detector._threshold_criterion))
+        | f"Score and Learn ({model_uuid})" >> RunScoreAndLearn(self._detector))
+
+    if self._detector._threshold_criterion:
+      ret = (
+          ret | f"Run Threshold Criterion ({model_uuid})" >>
+          RunThresholdCriterion(self._detector._threshold_criterion))
 
     return ret
 
@@ -297,17 +296,17 @@ def expand(
     else:
       aggregation_type = "Custom"
 
-    aggregated = (
+    ret = (
         results | beam.Flatten()
         | f"Run {aggregation_type} Aggregation Strategy ({model_uuid})" >>
         RunAggregationStrategy(
            self._ensemble_detector._aggregation_strategy,
            self._ensemble_detector._model_id))
 
-    ret: Any = (
-        aggregated
-        | f"Run Threshold Criterion ({model_uuid})" >> RunThresholdCriterion(
-            self._ensemble_detector._threshold_criterion))
+    if self._ensemble_detector._threshold_criterion:
+      ret = (
+          ret | f"Run Threshold Criterion ({model_uuid})" >>
+          RunThresholdCriterion(self._ensemble_detector._threshold_criterion))
 
     return ret
 
@@ -376,6 +375,6 @@ def expand(
     # remove the temporary key and simplify the output.
     remove_temp_key_fn: Callable[[KeyedOutputT], OutputT] \
         = lambda e: (e[0], e[1][1])
-    ret: Any = keyed_output | "Remove temp key" >> beam.Map(remove_temp_key_fn)
+    ret = keyed_output | "Remove temp key" >> beam.Map(remove_temp_key_fn)
 
     return ret