Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 57 additions & 15 deletions sdks/python/apache_beam/ml/anomaly/specifiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@
import dataclasses
import inspect
import logging
import os
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import List
from typing import Optional
from typing import Protocol
from typing import Type
from typing import TypeVar
from typing import Union
from typing import runtime_checkable

from typing_extensions import Self
Expand Down Expand Up @@ -65,9 +69,11 @@ def _class_to_subspace(cls: Type) -> str:
subspace list. This is usually called when registering a new specifiable
class.
"""
for c in cls.mro():
if c.__name__ in _ACCEPTED_SUBSPACES:
return c.__name__
if hasattr(cls, "mro"):
# some classes do not have "mro", such as functions.
for c in cls.mro():
if c.__name__ in _ACCEPTED_SUBSPACES:
return c.__name__

return _FALLBACK_SUBSPACE

Expand All @@ -92,8 +98,10 @@ class Spec():
"""
#: A string indicating the concrete `Specifiable` class
type: str
#: A dictionary of keyword arguments for the `__init__` method of the class.
config: dict[str, Any] = dataclasses.field(default_factory=dict)
#: An optional dictionary of keyword arguments for the `__init__` method of
#: the class. If None, when we materialize this Spec, we only return the
#: class without instantiate any objects from it.
config: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict)


@runtime_checkable
Expand Down Expand Up @@ -122,7 +130,7 @@ def _from_spec_helper(v, _run_init):
return v

@classmethod
def from_spec(cls, spec: Spec, _run_init: bool = True) -> Self:
def from_spec(cls, spec: Spec, _run_init: bool = True) -> Union[Self, type]:
"""Generate a `Specifiable` subclass object based on a spec.

Args:
Expand All @@ -137,9 +145,15 @@ def from_spec(cls, spec: Spec, _run_init: bool = True) -> Self:

subspace = _spec_type_to_subspace(spec.type)
subclass: Type[Self] = _KNOWN_SPECIFIABLE[subspace].get(spec.type, None)

if subclass is None:
raise ValueError(f"Unknown spec type '{spec.type}' in {spec}")

if spec.config is None:
# when functions or classes are used as arguments, we won't try to
# create an instance.
return subclass

kwargs = {
k: Specifiable._from_spec_helper(v, _run_init)
for k,
Expand All @@ -158,6 +172,16 @@ def _to_spec_helper(v):
if isinstance(v, List):
return [Specifiable._to_spec_helper(e) for e in v]

if inspect.isfunction(v):
if not hasattr(v, "spec_type"):
_register(v, inject_spec_type=False)
return Spec(type=_get_default_spec_type(v), config=None)

if inspect.isclass(v):
if not hasattr(v, "spec_type"):
_register(v, inject_spec_type=False)
return Spec(type=_get_default_spec_type(v), config=None)

return v

def to_spec(self) -> Spec:
Expand All @@ -180,23 +204,40 @@ def run_original_init(self) -> None:
pass


def _get_default_spec_type(cls):
spec_type = cls.__name__
if inspect.isfunction(cls) and cls.__name__ == "<lambda>":
# for lambda functions, we need to include more information to distinguish
# among them
spec_type = '<lambda at %s:%s>' % (
os.path.basename(cls.__code__.co_filename), cls.__code__.co_firstlineno)

return spec_type


# Register a `Specifiable` subclass in `KNOWN_SPECIFIABLE`
def _register(cls, spec_type=None) -> None:
def _register(cls, spec_type=None, inject_spec_type=True) -> None:
assert spec_type is None or inject_spec_type, \
"need to inject spec_type to class if spec_type is not None"
if spec_type is None:
# By default, spec type is the class name. Users can override this with
# other unique identifier.
spec_type = cls.__name__
# Use default spec_type for a class if users do not specify one.
spec_type = _get_default_spec_type(cls)

subspace = _class_to_subspace(cls)
if spec_type in _KNOWN_SPECIFIABLE[subspace]:
raise ValueError(
f"{spec_type} is already registered for "
f"specifiable class {_KNOWN_SPECIFIABLE[subspace][spec_type]}. "
"Please specify a different spec_type by @specifiable(spec_type=...).")
if cls is not _KNOWN_SPECIFIABLE[subspace][spec_type]:
# only raise exception if we register the same spec type with a different
# class
raise ValueError(
f"{spec_type} is already registered for "
f"specifiable class {_KNOWN_SPECIFIABLE[subspace][spec_type]}. "
"Please specify a different spec_type by @specifiable(spec_type=...)."
)
else:
_KNOWN_SPECIFIABLE[subspace][spec_type] = cls

cls.spec_type = spec_type
if inject_spec_type:
cls.spec_type = spec_type


# Keep a copy of arguments that are used to call the `__init__` method when the
Expand Down Expand Up @@ -331,6 +372,7 @@ def new_getattr(self, name):
cls._to_spec_helper = staticmethod(Specifiable._to_spec_helper)
cls.from_spec = Specifiable.from_spec
cls._from_spec_helper = staticmethod(Specifiable._from_spec_helper)

return cls
# end of the function body of _wrapper

Expand Down
111 changes: 106 additions & 5 deletions sdks/python/apache_beam/ml/anomaly/specifiable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import copy
import dataclasses
import logging
import os
import unittest
from typing import List
from typing import Optional
Expand All @@ -43,6 +44,9 @@ def test_decorator_in_function_form(self):
class A():
pass

class B():
pass

# class is not decorated and thus not registered
self.assertNotIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE])

Expand All @@ -53,8 +57,11 @@ class A():
self.assertIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE])
self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A"], A)

# an error is raised if the specified spec_type already exists.
self.assertRaises(ValueError, specifiable, A)
# Re-registering spec_type with the same class is allowed
A = specifiable(A)

# Raise an error when re-registering spec_type with a different class
self.assertRaises(ValueError, specifiable(spec_type='A'), B)

# apply the decorator function to an existing class with a different
# spec_type
Expand All @@ -64,9 +71,6 @@ class A():
self.assertIn("A_DUP", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE])
self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A_DUP"], A)

# an error is raised if the specified spec_type already exists.
self.assertRaises(ValueError, specifiable(spec_type="A_DUP"), A)

def test_decorator_in_syntactic_sugar_form(self):
# call decorator without parameters
@specifiable
Expand Down Expand Up @@ -484,6 +488,103 @@ def test_error_in_child(self):
self.assertEqual(Child_2.counter, 0)


def my_normal_func(x, y):
return x + y


@specifiable
class Wrapper():
def __init__(self, func=None, cls=None, **kwargs):
self._func = func
if cls is not None:
self._cls = cls(**kwargs)

def run_func(self, x, y):
return self._func(x, y)

def run_func_in_class(self, x, y):
return self._cls.apply(x, y)


class TestFunctionAsArgument(unittest.TestCase):
def setUp(self) -> None:
self.saved_specifiable = copy.deepcopy(_KNOWN_SPECIFIABLE)

def tearDown(self) -> None:
_KNOWN_SPECIFIABLE.clear()
_KNOWN_SPECIFIABLE.update(self.saved_specifiable)

def test_normal_function(self):
w = Wrapper(my_normal_func)

self.assertEqual(w.run_func(1, 2), 3)

w_spec = w.to_spec()
self.assertEqual(
w_spec,
Spec(
type='Wrapper',
config={'func': Spec(type="my_normal_func", config=None)}))

w_2 = Specifiable.from_spec(w_spec)
self.assertEqual(w_2.run_func(2, 3), 5)

def test_lambda_function(self):
my_lambda_func = lambda x, y: x - y

w = Wrapper(my_lambda_func)

self.assertEqual(w.run_func(3, 2), 1)

w_spec = w.to_spec()
self.assertEqual(
w_spec,
Spec(
type='Wrapper',
config={
'func': Spec(
type=
f"<lambda at {os.path.basename(__file__)}:{my_lambda_func.__code__.co_firstlineno}>", # pylint: disable=line-too-long
config=None)
}
))

w_2 = Specifiable.from_spec(w_spec)
self.assertEqual(w_2.run_func(5, 3), 2)


class TestClassAsArgument(unittest.TestCase):
def setUp(self) -> None:
self.saved_specifiable = copy.deepcopy(_KNOWN_SPECIFIABLE)

def tearDown(self) -> None:
_KNOWN_SPECIFIABLE.clear()
_KNOWN_SPECIFIABLE.update(self.saved_specifiable)

def test_normal_class(self):
class InnerClass():
def __init__(self, multiplier):
self._multiplier = multiplier

def apply(self, x, y):
return x * y * self._multiplier

w = Wrapper(cls=InnerClass, multiplier=10)
self.assertEqual(w.run_func_in_class(2, 3), 60)

w_spec = w.to_spec()
self.assertEqual(
w_spec,
Spec(
type='Wrapper',
config={
'cls': Spec(type='InnerClass', config=None), 'multiplier': 10
}))

w_2 = Specifiable.from_spec(w_spec)
self.assertEqual(w_2.run_func_in_class(5, 3), 150)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
7 changes: 7 additions & 0 deletions sdks/python/apache_beam/ml/anomaly/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses
import uuid
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Tuple
Expand Down Expand Up @@ -55,6 +56,8 @@ class _ScoreAndLearnDoFn(beam.DoFn):

def __init__(self, detector_spec: Spec):
self._detector_spec = detector_spec

assert isinstance(self._detector_spec.config, Dict)
self._detector_spec.config["_run_init"] = True

def score_and_learn(self, data):
Expand Down Expand Up @@ -172,8 +175,10 @@ class _StatelessThresholdDoFn(_BaseThresholdDoFn):
creation of a stateful `ThresholdFn`.
"""
def __init__(self, threshold_fn_spec: Spec):
assert isinstance(threshold_fn_spec.config, Dict)
threshold_fn_spec.config["_run_init"] = True
self._threshold_fn = Specifiable.from_spec(threshold_fn_spec)
assert isinstance(self._threshold_fn, ThresholdFn)
assert not self._threshold_fn.is_stateful, \
"This DoFn can only take stateless function as threshold_fn"

Expand Down Expand Up @@ -217,8 +222,10 @@ class _StatefulThresholdDoFn(_BaseThresholdDoFn):
THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec('saved_tracker', DillCoder())

def __init__(self, threshold_fn_spec: Spec):
assert isinstance(threshold_fn_spec.config, Dict)
threshold_fn_spec.config["_run_init"] = True
threshold_fn = Specifiable.from_spec(threshold_fn_spec)
assert isinstance(threshold_fn, ThresholdFn)
assert threshold_fn.is_stateful, \
"This DoFn can only take stateful function as threshold_fn"
self._threshold_fn_spec = threshold_fn_spec
Expand Down
Loading