Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions sdks/python/apache_beam/ml/anomaly/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from typing import Iterable
from typing import Optional

from apache_beam.ml.anomaly.base import DEFAULT_MISSING_LABEL
from apache_beam.ml.anomaly.base import DEFAULT_NORMAL_LABEL
from apache_beam.ml.anomaly.base import DEFAULT_OUTLIER_LABEL
from apache_beam.ml.anomaly.base import AggregationFn
from apache_beam.ml.anomaly.base import AnomalyPrediction
from apache_beam.ml.anomaly.specifiable import specifiable
Expand Down Expand Up @@ -69,9 +72,13 @@ def __init__(
agg_func: Callable[[Iterable[int]], int],
agg_model_id: Optional[str] = None,
include_source_predictions: bool = False,
missing_label: int = -2,
normal_label: int = DEFAULT_NORMAL_LABEL,
outlier_label: int = DEFAULT_OUTLIER_LABEL,
missing_label: int = DEFAULT_MISSING_LABEL,
):
self._agg = agg_func
self._normal_label = normal_label
self._outlier_label = outlier_label
self._missing_label = missing_label
_AggModelIdMixin.__init__(self, agg_model_id)
_SourcePredictionMixin.__init__(self, include_source_predictions)
Expand Down Expand Up @@ -208,10 +215,8 @@ class MajorityVote(LabelAggregation):
**kwargs: Additional keyword arguments to pass to the base
`LabelAggregation` class.
"""
def __init__(self, normal_label=0, outlier_label=1, tie_breaker=0, **kwargs):
def __init__(self, tie_breaker=DEFAULT_NORMAL_LABEL, **kwargs):
self._tie_breaker = tie_breaker
self._normal_label = normal_label
self._outlier_label = outlier_label

def inner(predictions: Iterable[int]) -> int:
counters = collections.Counter(predictions)
Expand Down Expand Up @@ -248,10 +253,7 @@ class AllVote(LabelAggregation):
**kwargs: Additional keyword arguments to pass to the base
`LabelAggregation` class.
"""
def __init__(self, normal_label=0, outlier_label=1, **kwargs):
self._normal_label = normal_label
self._outlier_label = outlier_label

def __init__(self, **kwargs):
def inner(predictions: Iterable[int]) -> int:
return self._outlier_label if all(
map(lambda p: p == self._outlier_label,
Expand Down Expand Up @@ -282,10 +284,7 @@ class AnyVote(LabelAggregation):
**kwargs: Additional keyword arguments to pass to the base
`LabelAggregation` class.
"""
def __init__(self, normal_label=0, outlier_label=1, **kwargs):
self._normal_label = normal_label
self._outlier_label = outlier_label

def __init__(self, **kwargs):
def inner(predictions: Iterable[int]) -> int:
return self._outlier_label if any(
map(lambda p: p == self._outlier_label,
Expand Down
10 changes: 7 additions & 3 deletions sdks/python/apache_beam/ml/anomaly/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
"EnsembleAnomalyDetector"
]

DEFAULT_NORMAL_LABEL = 0
DEFAULT_OUTLIER_LABEL = 1
DEFAULT_MISSING_LABEL = -2


@dataclass(frozen=True)
class AnomalyPrediction():
Expand Down Expand Up @@ -79,9 +83,9 @@ class ThresholdFn(abc.ABC):
"""
def __init__(
self,
normal_label: int = 0,
outlier_label: int = 1,
missing_label: int = -2):
normal_label: int = DEFAULT_NORMAL_LABEL,
outlier_label: int = DEFAULT_OUTLIER_LABEL,
missing_label: int = DEFAULT_MISSING_LABEL):
self._normal_label = normal_label
self._outlier_label = outlier_label
self._missing_label = missing_label
Expand Down
25 changes: 22 additions & 3 deletions sdks/python/apache_beam/ml/anomaly/specifiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,15 @@ def _from_spec_helper(v, _run_init):

@classmethod
def from_spec(cls, spec: Spec, _run_init: bool = True) -> Self:
"""Generate a `Specifiable` subclass object based on a spec."""
"""Generate a `Specifiable` subclass object based on a spec.

Args:
spec: the specification of a `Specifiable` subclass object
_run_init: whether to call `__init__` or not for the initial instantiation

Returns:
Self: the `Specifiable` subclass object
"""
if spec.type is None:
raise ValueError(f"Spec type not found in {spec}")

Expand Down Expand Up @@ -153,7 +161,11 @@ def _to_spec_helper(v):
return v

def to_spec(self) -> Spec:
"""Generate a spec from a `Specifiable` subclass object."""
"""Generate a spec from a `Specifiable` subclass object.

Returns:
Spec: The specification of the instance.
"""
if getattr(type(self), 'spec_type', None) is None:
raise ValueError(
f"'{type(self).__name__}' not registered as Specifiable. "
Expand Down Expand Up @@ -262,7 +274,14 @@ def new_init(self: Specifiable, *args, **kwargs):
original_init(self, *args, **kwargs)
self._initialized = True

def run_original_init(self):
def run_original_init(self) -> None:
"""Execute the original `__init__` method with its saved arguments.

For instances of the `Specifiable` class, initialization is deferred
(lazy initialization). This function forces the execution of the
original `__init__` method using the arguments captured during
the object's initial instantiation.
"""
self._in_init = True
original_init(self, **self.init_kwargs)
self._in_init = False
Expand Down
159 changes: 0 additions & 159 deletions sdks/python/apache_beam/ml/anomaly/thresholds.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,172 +17,13 @@

from __future__ import annotations

import dataclasses
import math
from typing import Any
from typing import Iterable
from typing import Optional
from typing import Tuple
from typing import Union

import apache_beam as beam
from apache_beam.coders import DillCoder
from apache_beam.ml.anomaly.base import AnomalyResult
from apache_beam.ml.anomaly.base import ThresholdFn
from apache_beam.ml.anomaly.specifiable import Spec
from apache_beam.ml.anomaly.specifiable import Specifiable
from apache_beam.ml.anomaly.specifiable import specifiable
from apache_beam.ml.anomaly.univariate.quantile import BufferedSlidingQuantileTracker # pylint: disable=line-too-long
from apache_beam.ml.anomaly.univariate.quantile import QuantileTracker
from apache_beam.transforms.userstate import ReadModifyWriteRuntimeState
from apache_beam.transforms.userstate import ReadModifyWriteStateSpec


class BaseThresholdDoFn(beam.DoFn):
  """Base DoFn that relabels anomaly predictions with a `ThresholdFn`.

  This class does not define `process`; it only stores the threshold
  specification and offers a helper that rewrites the label and threshold of
  every prediction inside an `AnomalyResult`. The helper reads
  `self._threshold_fn`, which this constructor leaves as `None`, so a subclass
  has to assign a usable threshold function before calling the helper.

  Args:
    threshold_fn_spec (Spec): Specification of the `ThresholdFn` to be used.
  """
  def __init__(self, threshold_fn_spec: Spec):
    # No threshold function is built here; only the spec is remembered.
    self._threshold_fn = None
    self._threshold_fn_spec = threshold_fn_spec

  def _apply_threshold_to_predictions(
      self, result: AnomalyResult) -> AnomalyResult:
    """Returns a copy of `result` whose predictions carry fresh labels.

    Args:
      result (AnomalyResult): The result whose predictions hold the scores to
        be thresholded.

    Returns:
      AnomalyResult: A copy of `result` produced by `dataclasses.replace`, in
        which each prediction has its `label` and `threshold` replaced.
    """
    threshold_fn = self._threshold_fn
    relabeled = []
    for prediction in result.predictions:
      # `apply` is invoked before `threshold` is read, for every prediction.
      new_label = threshold_fn.apply(prediction.score)
      new_threshold = threshold_fn.threshold
      relabeled.append(
          dataclasses.replace(
              prediction, label=new_label, threshold=new_threshold))
    return dataclasses.replace(result, predictions=relabeled)


class StatelessThresholdDoFn(BaseThresholdDoFn):
  """DoFn that thresholds anomaly results with a stateless `ThresholdFn`.

  The `ThresholdFn` is created once, in the constructor, from the given spec
  and is then reused for every element; nothing is read from or written to
  Beam state.

  Args:
    threshold_fn_spec (Spec): Specification of the `ThresholdFn` to be used.
      Note that its `config` is modified in place (`"_run_init"` is set).

  Raises:
    AssertionError: If the `ThresholdFn` built from `threshold_fn_spec`
      reports itself as stateful.
  """
  def __init__(self, threshold_fn_spec: Spec):
    # NOTE(review): presumably this makes `from_spec` run the original
    # `__init__` right away -- confirm against `Specifiable`.
    threshold_fn_spec.config["_run_init"] = True
    threshold_fn: Any = Specifiable.from_spec(threshold_fn_spec)
    self._threshold_fn = threshold_fn
    assert not threshold_fn.is_stateful, (
        "This DoFn can only take stateless function as threshold_fn")

  def process(self, element: Tuple[Any, Tuple[Any, AnomalyResult]],
              **kwargs) -> Iterable[Tuple[Any, Tuple[Any, AnomalyResult]]]:
    """Relabels the predictions of one keyed `AnomalyResult`.

    Args:
      element (Tuple[Any, Tuple[Any, AnomalyResult]]): An element shaped as
        `(key1, (key2, AnomalyResult))`.
      **kwargs: Extra keyword arguments; they are not used here.

    Yields:
      Tuple[Any, Tuple[Any, AnomalyResult]]: One element with the same keys
        as the input and an `AnomalyResult` whose predictions have updated
        labels and thresholds.
    """
    outer_key, (inner_key, result) = element
    relabeled = self._apply_threshold_to_predictions(result)
    yield outer_key, (inner_key, relabeled)


class StatefulThresholdDoFn(BaseThresholdDoFn):
  """DoFn that thresholds anomaly results with a stateful `ThresholdFn`.

  The `ThresholdFn` instance is kept in a Beam read-modify-write state cell
  (`THRESHOLD_STATE_INDEX`). For every element the instance is read from
  state, created from the stored spec when the cell is empty, used to relabel
  the predictions, and finally written back to state.

  Args:
    threshold_fn_spec (Spec): Specification of the `ThresholdFn` to be used.
      Note that its `config` is modified in place (`"_run_init"` is set).

  Raises:
    AssertionError: If the `ThresholdFn` built from `threshold_fn_spec`
      reports itself as stateless.
  """
  # State cell holding the threshold function, encoded with `DillCoder`.
  THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec('saved_tracker', DillCoder())

  def __init__(self, threshold_fn_spec: Spec):
    # NOTE(review): presumably this makes `from_spec` run the original
    # `__init__` right away -- confirm against `Specifiable`.
    threshold_fn_spec.config["_run_init"] = True
    # This instance only serves the statefulness check below; the instance
    # used in `process` comes from Beam state or is rebuilt from the spec.
    probe: Any = Specifiable.from_spec(threshold_fn_spec)
    assert probe.is_stateful, (
        "This DoFn can only take stateful function as threshold_fn")
    self._threshold_fn_spec = threshold_fn_spec

  def process(
      self,
      element: Tuple[Any, Tuple[Any, AnomalyResult]],
      threshold_state: Union[ReadModifyWriteRuntimeState,
                             Any] = beam.DoFn.StateParam(THRESHOLD_STATE_INDEX),
      **kwargs) -> Iterable[Tuple[Any, Tuple[Any, AnomalyResult]]]:
    """Relabels one keyed `AnomalyResult` and persists the `ThresholdFn`.

    Args:
      element (Tuple[Any, Tuple[Any, AnomalyResult]]): An element shaped as
        `(key1, (key2, AnomalyResult))`.
      threshold_state (Union[ReadModifyWriteRuntimeState, Any]): The state
        cell declared by `THRESHOLD_STATE_INDEX`, from which the
        `ThresholdFn` is read and to which it is written back.
      **kwargs: Extra keyword arguments; they are not used here.

    Yields:
      Tuple[Any, Tuple[Any, AnomalyResult]]: One element with the same keys
        as the input and an `AnomalyResult` whose predictions have updated
        labels and thresholds.
    """
    outer_key, (inner_key, result) = element

    # An empty state cell yields None; in that case start from the spec.
    self._threshold_fn = threshold_state.read()
    if self._threshold_fn is None:
      self._threshold_fn = Specifiable.from_spec(self._threshold_fn_spec)

    relabeled = self._apply_threshold_to_predictions(result)
    yield outer_key, (inner_key, relabeled)

    # Runs only once the generator is resumed after the yield above.
    threshold_state.write(self._threshold_fn)


@specifiable
Expand Down
Loading
Loading