From b97d4bd1e9c25cb4bd1a4865a2d1d6418d9b1554 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 11 Mar 2025 21:08:33 -0400 Subject: [PATCH 1/2] Support functions and classes as init arguments in specifiable. --- .../apache_beam/ml/anomaly/specifiable.py | 68 ++++++++--- .../ml/anomaly/specifiable_test.py | 111 +++++++++++++++++- 2 files changed, 160 insertions(+), 19 deletions(-) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index e73ef5513b64..c316c0c1b9a0 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -25,9 +25,11 @@ import dataclasses import inspect import logging +import os from typing import Any from typing import ClassVar from typing import List +from typing import Optional from typing import Protocol from typing import Type from typing import TypeVar @@ -65,9 +67,11 @@ def _class_to_subspace(cls: Type) -> str: subspace list. This is usually called when registering a new specifiable class. """ - for c in cls.mro(): - if c.__name__ in _ACCEPTED_SUBSPACES: - return c.__name__ + if hasattr(cls, "mro"): + # some classes do not have "mro", such as functions. + for c in cls.mro(): + if c.__name__ in _ACCEPTED_SUBSPACES: + return c.__name__ return _FALLBACK_SUBSPACE @@ -92,8 +96,10 @@ class Spec(): """ #: A string indicating the concrete `Specifiable` class type: str - #: A dictionary of keyword arguments for the `__init__` method of the class. - config: dict[str, Any] = dataclasses.field(default_factory=dict) + #: An optional dictionary of keyword arguments for the `__init__` method of + #: the class. If None, when we materialize this Spec, we only return the + #: class without instantiate any objects from it. + config: Optional[dict[str, Any]] = dataclasses.field(default_factory=dict) @runtime_checkable @@ -137,9 +143,15 @@ def from_spec(cls, spec: Spec, _run_init: bool = True) -> Self: subspace = _spec_type_to_subspace(spec.type) subclass: Type[Self] = _KNOWN_SPECIFIABLE[subspace].get(spec.type, None) + if subclass is None: raise ValueError(f"Unknown spec type '{spec.type}' in {spec}") + if spec.config is None: + # when functions or classes are used as arguments, we won't try to + # create an instance. + return subclass + kwargs = { k: Specifiable._from_spec_helper(v, _run_init) for k, @@ -158,6 +170,16 @@ def _to_spec_helper(v): if isinstance(v, List): return [Specifiable._to_spec_helper(e) for e in v] + if inspect.isfunction(v): + if not hasattr(v, "spec_type"): + _register(v, inject_spec_type=False) + return Spec(type=_get_default_spec_type(v), config=None) + + if inspect.isclass(v): + if not hasattr(v, "spec_type"): + _register(v, inject_spec_type=False) + return Spec(type=_get_default_spec_type(v), config=None) + return v def to_spec(self) -> Spec: @@ -180,23 +202,40 @@ def run_original_init(self) -> None: pass +def _get_default_spec_type(cls): + spec_type = cls.__name__ + if inspect.isfunction(cls) and cls.__name__ == "": + # for lambda functions, we need to include more information to distinguish + # among them + spec_type = '' % ( + os.path.basename(cls.__code__.co_filename), cls.__code__.co_firstlineno) + + return spec_type + + # Register a `Specifiable` subclass in `KNOWN_SPECIFIABLE` -def _register(cls, spec_type=None) -> None: +def _register(cls, spec_type=None, inject_spec_type=True) -> None: + assert spec_type is None or inject_spec_type, \ + "need to inject spec_type to class if spec_type is not None" if spec_type is None: - # By default, spec type is the class name. Users can override this with - # other unique identifier. - spec_type = cls.__name__ + # Use default spec_type for a class if users do not specify one. + spec_type = _get_default_spec_type(cls) subspace = _class_to_subspace(cls) if spec_type in _KNOWN_SPECIFIABLE[subspace]: - raise ValueError( - f"{spec_type} is already registered for " - f"specifiable class {_KNOWN_SPECIFIABLE[subspace][spec_type]}. " - "Please specify a different spec_type by @specifiable(spec_type=...).") + if cls is not _KNOWN_SPECIFIABLE[subspace][spec_type]: + # only raise exception if we register the same spec type with a different + # class + raise ValueError( + f"{spec_type} is already registered for " + f"specifiable class {_KNOWN_SPECIFIABLE[subspace][spec_type]}. " + "Please specify a different spec_type by @specifiable(spec_type=...)." + ) else: _KNOWN_SPECIFIABLE[subspace][spec_type] = cls - cls.spec_type = spec_type + if inject_spec_type: + cls.spec_type = spec_type # Keep a copy of arguments that are used to call the `__init__` method when the @@ -331,6 +370,7 @@ def new_getattr(self, name): cls._to_spec_helper = staticmethod(Specifiable._to_spec_helper) cls.from_spec = Specifiable.from_spec cls._from_spec_helper = staticmethod(Specifiable._from_spec_helper) + return cls # end of the function body of _wrapper diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py index a3133f32e996..4c1a7bdaf32a 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py @@ -18,6 +18,7 @@ import copy import dataclasses import logging +import os import unittest from typing import List from typing import Optional @@ -43,6 +44,9 @@ def test_decorator_in_function_form(self): class A(): pass + class B(): + pass + # class is not decorated and thus not registered self.assertNotIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) @@ -53,8 +57,11 @@ class A(): self.assertIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A"], A) - # an error is raised if the specified spec_type already exists. - self.assertRaises(ValueError, specifiable, A) + # Re-registering spec_type with the same class is allowed + A = specifiable(A) + + # Raise an error when re-registering spec_type with a different class + self.assertRaises(ValueError, specifiable(spec_type='A'), B) # apply the decorator function to an existing class with a different # spec_type @@ -64,9 +71,6 @@ class A(): self.assertIn("A_DUP", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]) self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A_DUP"], A) - # an error is raised if the specified spec_type already exists. - self.assertRaises(ValueError, specifiable(spec_type="A_DUP"), A) - def test_decorator_in_syntactic_sugar_form(self): # call decorator without parameters @specifiable @@ -484,6 +488,103 @@ def test_error_in_child(self): self.assertEqual(Child_2.counter, 0) +def my_normal_func(x, y): + return x + y + + +@specifiable +class Wrapper(): + def __init__(self, func=None, cls=None, **kwargs): + self._func = func + if cls is not None: + self._cls = cls(**kwargs) + + def run_func(self, x, y): + return self._func(x, y) + + def run_func_in_class(self, x, y): + return self._cls.apply(x, y) + + +class TestFunctionAsArgument(unittest.TestCase): + def setUp(self) -> None: + self.saved_specifiable = copy.deepcopy(_KNOWN_SPECIFIABLE) + + def tearDown(self) -> None: + _KNOWN_SPECIFIABLE.clear() + _KNOWN_SPECIFIABLE.update(self.saved_specifiable) + + def test_normal_function(self): + w = Wrapper(my_normal_func) + + self.assertEqual(w.run_func(1, 2), 3) + + w_spec = w.to_spec() + self.assertEqual( + w_spec, + Spec( + type='Wrapper', + config={'func': Spec(type="my_normal_func", config=None)})) + + w_2 = Specifiable.from_spec(w_spec) + self.assertEqual(w_2.run_func(2, 3), 5) + + def test_lambda_function(self): + my_lambda_func = lambda x, y: x - y + + w = Wrapper(my_lambda_func) + + self.assertEqual(w.run_func(3, 2), 1) + + w_spec = w.to_spec() + self.assertEqual( + w_spec, + Spec( + type='Wrapper', + config={ + 'func': Spec( + type= + f"", # pylint: disable=line-too-long + config=None) + } + )) + + w_2 = Specifiable.from_spec(w_spec) + self.assertEqual(w_2.run_func(5, 3), 2) + + +class TestClassAsArgument(unittest.TestCase): + def setUp(self) -> None: + self.saved_specifiable = copy.deepcopy(_KNOWN_SPECIFIABLE) + + def tearDown(self) -> None: + _KNOWN_SPECIFIABLE.clear() + _KNOWN_SPECIFIABLE.update(self.saved_specifiable) + + def test_normal_class(self): + class InnerClass(): + def __init__(self, multiplier): + self._multiplier = multiplier + + def apply(self, x, y): + return x * y * self._multiplier + + w = Wrapper(cls=InnerClass, multiplier=10) + self.assertEqual(w.run_func_in_class(2, 3), 60) + + w_spec = w.to_spec() + self.assertEqual( + w_spec, + Spec( + type='Wrapper', + config={ + 'cls': Spec(type='InnerClass', config=None), 'multiplier': 10 + })) + + w_2 = Specifiable.from_spec(w_spec) + self.assertEqual(w_2.run_func_in_class(5, 3), 150) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() From 8b0d13571f2b85726ec9524dd2133288e4d15e67 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Thu, 13 Mar 2025 11:28:27 -0400 Subject: [PATCH 2/2] Fix lints. --- sdks/python/apache_beam/ml/anomaly/specifiable.py | 6 ++++-- sdks/python/apache_beam/ml/anomaly/transforms.py | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py index c316c0c1b9a0..2eeb1d0de76d 100644 --- a/sdks/python/apache_beam/ml/anomaly/specifiable.py +++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py @@ -28,11 +28,13 @@ import os from typing import Any from typing import ClassVar +from typing import Dict from typing import List from typing import Optional from typing import Protocol from typing import Type from typing import TypeVar +from typing import Union from typing import runtime_checkable from typing_extensions import Self @@ -99,7 +101,7 @@ class Spec(): #: An optional dictionary of keyword arguments for the `__init__` method of #: the class. If None, when we materialize this Spec, we only return the #: class without instantiate any objects from it. - config: Optional[dict[str, Any]] = dataclasses.field(default_factory=dict) + config: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict) @runtime_checkable @@ -128,7 +130,7 @@ def _from_spec_helper(v, _run_init): return v @classmethod - def from_spec(cls, spec: Spec, _run_init: bool = True) -> Self: + def from_spec(cls, spec: Spec, _run_init: bool = True) -> Union[Self, type]: """Generate a `Specifiable` subclass object based on a spec. Args: diff --git a/sdks/python/apache_beam/ml/anomaly/transforms.py b/sdks/python/apache_beam/ml/anomaly/transforms.py index 35cae18a7224..08b656072ac8 100644 --- a/sdks/python/apache_beam/ml/anomaly/transforms.py +++ b/sdks/python/apache_beam/ml/anomaly/transforms.py @@ -18,6 +18,7 @@ import dataclasses import uuid from typing import Callable +from typing import Dict from typing import Iterable from typing import Optional from typing import Tuple @@ -55,6 +56,8 @@ class _ScoreAndLearnDoFn(beam.DoFn): def __init__(self, detector_spec: Spec): self._detector_spec = detector_spec + + assert isinstance(self._detector_spec.config, Dict) self._detector_spec.config["_run_init"] = True def score_and_learn(self, data): @@ -172,8 +175,10 @@ class _StatelessThresholdDoFn(_BaseThresholdDoFn): creation of a stateful `ThresholdFn`. """ def __init__(self, threshold_fn_spec: Spec): + assert isinstance(threshold_fn_spec.config, Dict) threshold_fn_spec.config["_run_init"] = True self._threshold_fn = Specifiable.from_spec(threshold_fn_spec) + assert isinstance(self._threshold_fn, ThresholdFn) assert not self._threshold_fn.is_stateful, \ "This DoFn can only take stateless function as threshold_fn" @@ -217,8 +222,10 @@ class _StatefulThresholdDoFn(_BaseThresholdDoFn): THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec('saved_tracker', DillCoder()) def __init__(self, threshold_fn_spec: Spec): + assert isinstance(threshold_fn_spec.config, Dict) threshold_fn_spec.config["_run_init"] = True threshold_fn = Specifiable.from_spec(threshold_fn_spec) + assert isinstance(threshold_fn, ThresholdFn) assert threshold_fn.is_stateful, \ "This DoFn can only take stateful function as threshold_fn" self._threshold_fn_spec = threshold_fn_spec